[Relay][PRNG] Add uniform distribution generator wrt threefry PRNG #8041
FrozenGene merged 9 commits into apache:main
Conversation
Parameters
----------
gen : Tensor[10, uint64]
This is the ThreefryKeyType introduced in #7083. Please refer to tvm/src/relay/op/random/kernel.cc, line 28 at c999a84.
If so, let us add a comment describing the meaning of 10.
You could probably say ThreefryKey instead of Tensor[10, uint64].
less than high.

out_shape : Sequence[int]
    Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
What is the reason the product must be a multiple of 4?
It's the property of the threefry key. Please refer to this comment: #7083 (comment)
Sorry, I've rethought this problem. There should not be any restriction on the output shape... We could change the input restriction of threefry_generate in another PR.
Do you mind sending a PR updating the threefry_generate output, or rather, what approach do you have in mind? I tried to avoid this problem by truncating the output buffer, but this required an extra copy. I wonder if you have something else.
@altanh Sorry that I'm not familiar with the threefry algorithm. Is it possible to call _threefry twice in threefry_generate in the following form? Something like:

out_array = irb.buffer_ptr(out_array_ptr)
# deal with most of the array
_threefry(irb, tmp, 0, tmp, 4, out_array, 0, out_len // 4)
if out_len % 4 != 0:
    # generate the remainders into a small tmp buffer
    tmp_array = irb.allocate(gen.dtype, 4, name="tmp", scope="global")
    # may need to update the tmp key in between
    # ...
    _threefry(irb, tmp, 0, tmp, 4, tmp_array, 0, 1)  # one extra block of 4
    # only copy the tmp buffer
    for i in range(out_len // 4 * 4, out_len):
        out_array[i] = tmp_array[i % 4]

In this way, we could avoid copying the whole generated tensor.
Yeah, you could do that. Maybe submit it in a new PR?
@tkonolige Sure, I will submit one. Could you tell me what kind of update on the key tmp we need before the second _threefry? I can only think of updating the increment counter (tmp[7]).
You'll need to update the counter buffer to be equal to out_len.
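A minimal sketch of that update (assuming, as in the snippet above, that index 7 of the key buffer holds the counter):

# Sketch only: advance the counter slot of the Threefry key so the second
# _threefry call continues the stream instead of repeating earlier blocks.
tmp[7] = tvm.tir.Cast("uint64", out_len)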
def test_uniform_infer():
    oshape = (12,)
    odtype = "float32"
Should cover more types, for example the float64 you have implemented.
standard_uniform_values = tvm.te.compute(out_shape, lambda *i: uniform_scalar(random_bits(*i)))
uniform_values = tvm.topi.add(tvm.topi.multiply(standard_uniform_values, high - low), low)
How well does this approach work when we have a large range (high - low)? It seems like we would be losing a lot of potential randomness.
Thanks for this PR! I will be reading it soon, and just wanted to point you to a branch I worked on a while ago where I hacked together a uniform op + dropout support: https://github.com/altanh/tvm/commits/prng (just in case it might be useful for you to check and compare).
Perhaps this is the operation you're looking for? https://github.com/altanh/tvm/blob/2d9ac7710ab055d4f20e8b5a0a3580836723efac/python/tvm/topi/generic/algorithm.py#L465 Thanks!
@FrozenGene @tkonolige @altanh Thank you for your reviews. I've updated this PR based on them.
@FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in
@tkonolige As this approach only uses the fraction bits to represent the float, there will be a loss of randomness for all floats, at least
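To put a rough number on that loss (an illustration added here, assuming the 23-fraction-bit scheme this PR describes): with 23 fraction bits the generator can emit at most 2**23 distinct values in [0, 1), so after scaling, the gap between adjacent possible outputs grows linearly with the range.

low, high = 0.0, 1e6
distinct = 1 << 23              # at most 2**23 distinct outputs in [0, 1)
step = (high - low) / distinct  # ~0.119 between adjacent possible outputs
print(distinct, step)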
@altanh Thank you for your references! The
Suggest
@FrozenGene Thank you. I've added the type restriction.
tkonolige left a comment:
Looks pretty good to me. Just a couple of minor fixes.
altanh left a comment:
Overall LGTM with some minor comments. I did want to request that we keep the output shape restriction in the documentation for now, until a follow-up PR is merged which relaxes it. Thanks for the work!
@FrozenGene @altanh @tkonolige I've updated the PR based on the reviews. Could you take another look? Thank you~
tkonolige left a comment:
Looks good to me. Just some small comments.
altanh left a comment:
LGTM! I'm a bit uneasy about introducing a nondeterministic test based on averaging the random numbers, but I imagine it will almost never fail. Also left a comment about comparing the min/max of the generated numbers - can we always guarantee <= or >= on the output, or will there be some floating point inaccuracy cases where this might be violated?
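For reference, a small NumPy check (an illustrative sketch, not part of the original review) suggests the exclusive upper bound can indeed be hit through rounding:

import numpy as np

# Round-to-nearest can push u * (high - low) + low up to exactly `high`,
# even though u < 1.0.
u = np.float32(1.0) - np.float32(2.0 ** -24)  # largest float32 below 1.0
low, high = np.float32(1.0), np.float32(2.0)
out = u * (high - low) + low                  # exact value is 2 - 2**-24
print(out == high)                            # True: ties-to-even rounds up to 2.0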
@FrozenGene Could you take another look at this PR? Thank you~
@FrozenGene Could you have another look at this PR? Thank you!
Thanks @zhuzilin @altanh @tkonolige, merged now.
[Relay][PRNG] Add uniform distribution generator wrt threefry PRNG (apache#8041)
* Add uniform distribution generator wrt threefry PRNG
* fix lint
* remove the redundant print
* modifications based on review
* update docs
* update uniform algorithm to use bit operations only
* add type restrictions
* minor fix upon review
* update test and error information
This PR adds a uniform distribution generator using the threefry PRNG introduced in #7083. We would need uniform to develop training-phase dropout, as in the following roadmap:
The algorithm used is basically the same as the one used in jax: using the random bits generated from threefry_generate as the fraction section of the float32 or float64. To be specific, I use the last 23 bits of the random bits for float32 and the last 52 for float64. There is one difference from the jax implementation: jax uses a bitcast to turn the uint into a float. However, as I haven't found a bitcast in te or topi, I use a divide to cast the type, which may be slower:
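A minimal NumPy sketch of the two schemes for the float32 case (a reconstruction for illustration, not the actual code in this PR; assumes uint32 random bits):

import numpy as np

def uniform_divide(bits):
    # Divide-based version: keep the low 23 bits as the fraction and
    # divide by 2**23 to land in [0, 1). No bitcast needed.
    frac = bits & np.uint32((1 << 23) - 1)
    return frac.astype(np.float32) / np.float32(1 << 23)

def uniform_bitcast(bits):
    # jax-style version: place 23 random bits in the mantissa of a
    # float32 whose exponent encodes [1, 2), bitcast, then subtract 1.
    pattern = (bits >> np.uint32(9)) | np.uint32(0x3F800000)
    return pattern.view(np.float32) - np.float32(1.0)

bits = np.random.randint(0, 2**32, size=8, dtype=np.uint32)
print(uniform_divide(bits))   # uniform samples in [0, 1)
print(uniform_bitcast(bits))  # same distribution, different rounding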
Thank you for your time reviewing this PR. I may not be familiar enough with the tvm codebase yet, so I'm sorry if I've broken any conventions in the community; I'd love to fix them :).
Gently ping @tqchen @altanh @tkonolige