SKYE WANDERMAN-MILNE: I'm Skye, for those who don't know me. I've been working on Control Flow in TensorFlow for quite some time, with the help of [? Sarab ?] and many other individuals on the team. And so my goal with this talk is to tell you everything I know about Control Flow that's important. Let's get started.

I'm going to start by going over the lay of the land with Control Flow in TensorFlow. So starting with what I'm going to call the base APIs, tf dot cond and tf dot while loop. These are the primitives that are exposed in the public Python API for users to access Control Flow. So you have conditional execution and loops. That's it. So you might be wondering, what about all the other Control Flow functions I know and love, like map or case? These are all built on those two base APIs, cond and while loop. They're sort of wrappers around them that add useful functionality.

So diving down into the stack, how are these primitives, cond and while, actually implemented? How are they represented in the graph? So in TensorFlow 1.x, we have these low-level Control Flow ops. You might have heard of them: Exit, Enter, NextIteration, Switch, and Merge. We'll talk more about these in a bit. There's also an alternate representation. That's what Control Flow version 2 is all about. These are the "functional" ops. And I put "functional" in quotes because it's caused some confusion in the past. It's not pure functional in the programming sense; there's still state. But they're higher-order functions that take functions as input. So now, the cond branches will be represented as functions. So these sort of do the same thing as the low-level ops, but the higher-level functionality is all wrapped up into a single op.

Moving back up the stack, you might be wondering what's going to happen with TensorFlow 2.0. If you're using Eager execution, you just write Python and you just use Python Control Flow: if statements, or loops, or list comprehensions, that kind of thing. So there's no arrow connecting it to this graph mode stuff. But if you use tf dot function, maybe some people have heard of Autograph, which is automatically included in tf dot function, and this attempts to take your eager-style, plain Python code and convert it into new Python code that calls the TensorFlow graph APIs. So it's going to try to rewrite all that Python Control Flow, your if statements and while loops, into tf dot cond and tf dot while loop. So note that Autograph is just dealing at this abstraction layer of the public TensorFlow API. It doesn't have to dive down into the low-level ops or anything like that.

So that's kind of where we're at. We have the 2.0 world where you just write Python that maybe gets converted into our public graph APIs, which in turn produce these various operators in the graph. And one more thing. Right now, in this new implementation of Control Flow, Control Flow version 2, we are still converting the functional ops back into the low-level ops. This is basically a performance optimization. I hope we don't have to do it in the future. That's why it's this faded dashed arrow.

So this talk, we're gonna focus on the base APIs and how they're implemented. I think there'll be another talk about Autograph, so hopefully they can talk about Control Flow there. Maybe there's also a talk about Eager execution and the high-level APIs that are not so complicated. So I leave that as an exercise to the viewer. OK.
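For reference, a minimal sketch of what calls to the two base APIs look like, plus a tf.function where AutoGraph does the rewriting. The tensors and the counting loop here are made-up examples, not from the talk:

```python
import tensorflow as tf

x = tf.constant(2.0)
y = tf.constant(3.0)

# tf.cond: pick one of two branch functions based on a boolean predicate.
out = tf.cond(x < y, lambda: x + y, lambda: tf.square(y))

# tf.while_loop: a predicate function, a body function, and initial loop variables.
i0, total0 = tf.constant(0), tf.constant(0)
i_final, total_final = tf.while_loop(
    cond=lambda i, total: i < 10,
    body=lambda i, total: (i + 1, total + i),
    loop_vars=[i0, total0])

# With tf.function, AutoGraph rewrites plain Python control flow into the
# calls above when the condition depends on a Tensor.
@tf.function
def sum_to(n):
  total = tf.constant(0)
  i = tf.constant(0)
  while i < n:          # becomes a tf.while_loop in the traced graph
    total += i
    i += 1
  return total
```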
So I'm going to start with going over Control Flow v1, the original low-level representation. You might be asking, why? Why do we care at all? So like I showed in the diagram, we do still convert the functional ops to this representation. So this is basically how it's executed today, always. Furthermore, this is still what we use in TensorFlow 1.x. So all 1.x code is using Control Flow v1. It's still very much alive. And I hope it provides a little bit of motivation for why we wanted to implement Control Flow using the functional ops.

So I'm going to start with these low-level ops. So up here, Switch and Merge are used for conditional execution, this is tf dot cond. They're also used in while loops to determine whether we need to keep iterating or we're done. And then Enter, Exit, and NextIteration are just used in while loops to manage the iterations.

So let's dive in. So Switch and Merge, these are for conditionals. Let's just start with Switch. The idea is you get your predicate tensor in, this is a Boolean, that tells you which conditional branch you want to take. And then it has a single data input, so [INAUDIBLE] some tensor. And it's just going to forward that data input to one of its two outputs depending on the predicate. So in this picture, the predicate must be false, and so the data's coming out of the false output. Merge basically does the opposite. It takes two inputs, but it only expects data from one of its inputs, and then it just produces a single output. So Switch is how you start your conditional execution, because it's going to divert that data into one branch. And then Merge brings it back together into your mainline execution, where it's not conditional anymore.

One implementation detail I'm going to mention here is dead tensors. So you might think that nothing is going to come out of the true output of the Switch, but it actually does output a special dead tensor, which is just a sentinel value. Like a little tiny thing. And dead tensors flow through the whole untaken conditional branch. And eventually, you're going to get a dead tensor into this Merge. It just ignores it and outputs whatever data tensor it gets. So dead tensors are needed for distributed Control Flow, which I'm actually not going to cover in this talk, because it's kind of technical and I haven't found it that important to know the details of it. It's covered in Yuan's paper. But I'm mentioning dead tensors because they do show up a lot in the execution. Like, if you look at the executor code, there's all this special-casing for dead tensors. This is what they're about: it's for conditional execution so we can do distribution.

SPEAKER 1: And retval zero doesn't help any.

SKYE WANDERMAN-MILNE: Oh, yeah. And that famous error message I want to put on a t-shirt, retval zero does not have a value. That means you're trying to read a dead tensor, or it probably means there's a bug. OK.

Moving on to the low-level ops we use for while loops. These manage iterations, basically. The concept you need to know about in execution is frames. So you have one frame per while loop execution. And this is what allows the executor to keep track of multiple iterations, and allows a single op to be run multiple times as you do multiple iterations. So a frame has a name, which is for the whole while loop, and it also has an iteration number. So the Enter op, that just establishes a new frame. It means we're starting a new while loop. So it just forwards its input. It's like an identity, except that output is now in this new frame.
And it has an attribute that's the frame name, and it starts at iteration 0. Exit's the opposite. It's just like an identity, except it strips the frame from its input, so the output is now not in that frame anymore. And these can be stacked. So if you have a bunch of Enters establishing a bunch of frames and then a bunch of Exits, it'll pop them off one at a time. NextIteration is just the final piece, in order to increment that iteration count. This might make more sense when we put it all together, so let's do that.

Starting with tf cond again. Let's just work through this. So down here, you have the API call that we're using. So we start, we have this predicate. Note that the predicate isn't actually part of the cond. It happens outside here, but then we feed it into the Switch operators. So the Switches and Merges mark the boundary of the conditional execution, remember. So we'll feed in this predicate, and then the true branch is an Add. So we have a Switch for each input, for x and z, which are the external tensors we use in that branch. You'll note that they are only being emitted from the true side of it. So if the false branch is taken, nothing's connected to that. The result comes out of the Add. Then similarly on the other side, we're Squaring y, so we have a Switch for the y. This time, it's going to be emitted from the false branch into the Square. And then, we only have one output from this cond, so we have a single Merge. Either the Square or the Add, only one of those is going to actually have data, and that's what will be output. So note that there is a Switch for each input and a Merge for each output; they don't have to match. And in this example, the two branches are using disjoint tensors. But say we did the Square of x instead of y, then you would have an edge from both the true output and the false output of x's Switch, going to the Add or the Square respectively.

Let's quickly, actually, go over the while loop API, just to make sure we all remember. So the first argument is a function. That's the predicate function. The second argument is the body function that we're going to execute. And this is where it's kind of interesting. So you have some inputs, these are called the loop variables, the input to the while loop. And then it's going to output updated versions of those same loop variables. So the inputs of the body match the outputs of the body. Like, same number, types, and shapes of tensors, because they're just the updated variables.

SPEAKER 2: Can't the shape-type [INAUDIBLE]

SKYE WANDERMAN-MILNE: The shape can change, you're right. Same number and types. And then finally, we provide some initial input to start it off. So that's the 0, the final argument. And then the output is going to be whatever the final value of the loop variables is. And then the predicate function takes those same loop variables as input but just outputs a Boolean: do we continue execution or not?

So now we'll start with the Enter node. This, remember, establishes the new frame. We're starting a new while loop. I guess it's called L for loop. We go through a Merge now, kind of reversed from the cond, where you start with the Switch. Now you start with a Merge, because it's choosing, is this the initial value or is this the new, updated value from an iteration? That feeds into the predicate. Note that the predicate is inside the while loop now, because it has to execute multiple times. The output goes into the Switch node; if it's false, we're going to exit the while loop with that Exit node.
Otherwise, we go into the body, which is an Add in this case, take the output of the body, and feed it to the NextIteration, because we have to bump that iteration count, remember? And then feed it back into the Merge, which will forward it back again and again, until we get to the Exit. So, hopefully, this kind of makes sense. You can see there's a loop in there. That's the while loop.

SPEAKER 3: For sequential ones, how does the Merge know to select the z or [INAUDIBLE]? Because wouldn't neither of them be dead tensors at that point?

SKYE WANDERMAN-MILNE: I don't know the details of how this is implemented. But I think because the frame is different, z only is in the first frame. Because each frame is conceptually like you made a copy of the body, it's going to keep track of different pending counts for each node in the body, or the Merge, or the Switch. So I think that's why. OK.

All right, so that's all I'm going to go over with Control Flow v1. It does have some advantages. It all kind of falls out of the fact that these low-level operators are designed to naturally fit within the dataflow model, because TensorFlow graphs are dataflow graphs. So you get nice features like pruning, which works pretty naturally, because it's all regular nodes, sort of, for pruning. You can have parallel execution of while loop iterations, which is actually pretty cool, I think. Because once you add in this frames logic, it kind of naturally keeps track of all the pending counts. It runs just like a regular-- like, if you unrolled the loop and the data will flow through as far as it can. Ops will be executed as soon as they can. It just kind of works.

However, there are some disadvantages. It's very complicated. Like, you can see that this is a bunch of nodes to express what in most programming languages is like one line, like while. This shows up especially in gradients and nested Control Flow. You end up with all these crazy edge cases where you didn't hook up the inner Merges correctly or whatever. As a result of this complexity, higher-order derivatives are not implemented. This is not a design problem, per se. It's just that it's so complicated and there are so many edge cases that no one has been able to do it, or has wanted to do it. Similarly to graph construction being complicated, the runtime is complicated, because you have to have all this dead tensor logic, all this frame logic, and it's very intricately baked into the executor. And this makes it hard to read and maintain, and also adds performance overhead. It's hard for other downstream things to analyze and make sense of. An example of this is [INAUDIBLE] has been trying to do [? auto ?] clustering for XLA, and so he has whole docs written on how to handle dead tensors, because they can show up anywhere. Similarly, XLA actually represents Control Flow in a functional way, with if and while ops. So when they consume TensorFlow graphs, they have to pattern-match this crazy stuff back into just the while op that originally produced it. And especially with gradients and nested Control Flow, it gets very complicated. There are a number of edge cases. This was actually one of the main motivations for building Control Flow v2. Because we were fixing so many bugs in how this was represented, in so many edge cases, that it's like, we just need a simpler representation. OK.

So, hopefully, this will be simpler. I can fit it on one slide for both. [LAUGHTER] So tf dot cond, it's just an If op now. You have the Boolean predicate coming in.
These arrows represent the type signature of the op, not individual tensors per se. So this could be any number and type of tensors coming in as input, and similarly, any number and type of tensors coming out. They don't have to match. Then these, they're technically function attributes, but they're basically functions attached to this op representing the true branch and the false branch. So they're like little subgraphs. One thing to note that's important with these functions is that the function signatures have to match. So the functions have the same inputs and the same outputs. The inputs and outputs don't have to match each other, but they have to match across the two branches.

SPEAKER 4: [INAUDIBLE] the type, not values?

SKYE WANDERMAN-MILNE: Yes. Sorry. Well, we're just talking signatures right now. So just type, and possibly shape in some cases. Yeah, it doesn't even have to be implemented this way, but it is. It makes some things simpler to think about. But keep that in mind.

Similarly, tf dot while loop just turns into a While op now. Now all our inputs and outputs are just the loop variables. Because, remember, the predicate takes those loop variables as inputs. So you have a cond function, or a predicate function, that takes the loop variables as input and outputs a Bool. And then the body function takes the loop variable inputs and outputs the updated versions, and eventually the final values will be the outputs from the op. So does this make sense? This picture.

SPEAKER 4: One thing to clarify is, tf cond doesn't actually have any concept of variables in the higher-level API. So these are things we capture, and we take care of making sure they match. So from the user's point of view, they don't have to do anything special.

SKYE WANDERMAN-MILNE: Right. That's, kind of, like the While op very closely matches the TensorFlow semantics. But the If op is a little bit different. They don't have to take [INAUDIBLE] inputs at all, because we do it through closures in the API. That's like, you do it within your code.

So if this is good for everyone, I'm going to move on to going over gradients. I'm going over how gradients work in Control Flow v2. It is somewhat general. It's much simpler to think about with the functional ops. So let's start at a high level. Just conceptually, what is the gradient of a cond? It's basically just another cond. You take the same predicate, and you take the gradient of both sides. So if we took the forward true branch, then we want to take the gradient of the true branch on the way back. Make sense? Hopefully, this is good.

While loops, a little bit more complicated, but not too bad. So say we have this forward while loop, you have your cond and body functions. Just assume it executes N times for now; we just know N. So now for the gradient, we have to execute the gradient of the body function N times. Like, we just have to do the reverse. Imagine an unrolled loop: we did N invocations of the body, now we're going to do N invocations of the gradient of the body. And you pass in the grad y's, or cotangents, or whatever, and those are your loop variables. Then your predicate, now, is just this counter to make sure we execute N times. So, hopefully, this makes sense. The big question is, how do we know what N is? The answer is that, at least in Control Flow v2, we just add a little counter to every while loop that just outputs the total number of iterations.
And we don't return this to the user, but we can wire it through to the gradient when we need it. Does this make sense at a high level? We're going to dive into the details, but this is the concept. OK.

So I'm about to go into more concrete examples. And I'm also going to discuss the tricky part about gradients, which is intermediate values. Basically, when you have a data dependency from the forward pass to the backwards pass. So let's start with cond. Here is a similar diagram. I rearranged it to make it fit nicer. But one important thing to notice is that now the arrows are actual tensors. They're not just type signatures anymore. So the predicate is a Boolean. In this example, there's only one input and one output, maybe they're different types, who knows. It doesn't matter for this example. And then you have the true and false functions with the same types. OK.

So here's the gradient function. It's just another If. This time we're dealing with the cotangents instead of the initial forward values. And we have the gradient of the true function and the gradient of the false function. Looks good so far. Hopefully. If there were no data dependencies between the forward and backward pass, like if you're doing y equals x plus 1, this is all you need. But what if somewhere in the forward pass, let's say the true function, there's an op, and we need to use its output in the backwards pass? So this is conceptually what we need to do. We need z in the gradient function. This is a problem, because you can't just have an edge between two function definitions. You need to have inputs and outputs. Like, the If ops need to be attached to each other with an edge. This doesn't make sense by itself.

So we're basically going to do just that. We're going to make inputs and outputs. We're going to add them to the If op. So let's do that. So we're going to output z from the true function, and then similarly add it as an output of the If op, because the If op is calling the true function. And then we're going to add it as an input to the gradient If op, and add it as an input to the gradient true function. OK, there's still one problem, though. And that's that now the true and false branches of both If ops don't match anymore. We need them to have the same signature. So let's just add some inputs and outputs. Starting on the gradient side, this is fine. We can just add z as an input to the false function. It's just going to ignore it, it's an unused input. But on the forward pass, this is a problem. Because we need to add z as an output of the false function, but we don't actually have anything to output. It's like, what is this question mark op? And it needs to be the same type, and possibly shape, if we want to keep a strong shape, or a fully known shape. And we might not know the shape until runtime.

So what do we do? I had to think about this for a long time and came up with many different solutions. And I partially implemented all of them before coming up with using Optionals. Optionals are like maybe types. You've heard of those? It's a special kind of tensor that can hold another tensor inside of it or not. So it's just a wrapper that may or may not have another tensor inside of it. And it's also a tensor. It's like a variant tensor. So the true function is going to return an Optional with the z value inside of it. The false function is going to return an Optional with no value inside of it. OK, great. Now they're the same type, Optional. They could have the same thing inside them.
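As a small illustration of the wrapper being described, here is a sketch using the public tf.experimental.Optional API. The internal gradient code uses its own optional-handling ops, so this only shows the concept, and z_spec and the values are made up:

```python
import tensorflow as tf

z_spec = tf.TensorSpec(shape=[], dtype=tf.float32)

# Taken branch: wrap the intermediate value z in an Optional.
wrapped = tf.experimental.Optional.from_value(tf.constant(4.0))

# Untaken branch: return an empty Optional of the same element type,
# so both branches have matching output signatures.
empty = tf.experimental.Optional.empty(z_spec)

print(wrapped.has_value())  # True
print(wrapped.get_value())  # the raw z, usable in the gradient branch
print(empty.has_value())    # False
```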
In the gradient true function, we can unwrap that Optional to get the raw z value. And then the false function still just ignores it, which is great, because there's nothing inside of it. I didn't know how to draw this, but that's what we do. So all the intermediate values that are needed by the gradient computation are added as Optional outputs of the forward pass. Does this make sense to everyone? That's it for cond gradients.

SPEAKER 3: Conceptually, what's the difference between doing this and the dead tensor stuff?

SKYE WANDERMAN-MILNE: Oh. Yeah. Great question. I meant to go over that, so thank you for asking. At a high level, this is just how it works in Control Flow v1. The gradient of a cond is another cond. You can express that in the low-level ops. But the dead tensors are the big difference. So v1 was, kind of, using dead tensors instead of Optionals. And you would just have that edge, because there are no functions [INAUDIBLE]. You could just draw that edge between the forward and backward pass. And if it's the untaken branch, you'll have a dead tensor flowing across that edge. There's none of this matching business, you just draw the edge.

SPEAKER 3: The interesting thing with the Optional is that it tells you in its type that it might be empty, whereas with the dead tensor you had no such information around.

SKYE WANDERMAN-MILNE: Right.

SPEAKER 3: So someone like [INAUDIBLE] doesn't have to spend as much time reverse engineering [INAUDIBLE] exactly what it was meant to do in complicated cases, like what tensors might be dead or not. So this is, essentially, a much more explicit way of making it clear what might be dead versus what might not.

SKYE WANDERMAN-MILNE: It's kind of, like, more complicated. Like, this was actually simpler in Control Flow v1, because you just draw the edge, and the executor will take care of all this dead tensor stuff. But yeah, it made the whole system more complicated as a whole to support that.

OK, so let's move on to while gradients. So again, we're dealing now with concrete tensors. So input x, output y. They have the same type but they are different values. The body function: note that I used x i because it's run multiple times, and each time the input might be x or it might be an intermediate value, and it outputs the updated value y i. Then I drew the cond function small, and I didn't draw its inputs and outputs, because they don't really matter that much for the gradient, but they're there. It does have them. So same thing for the gradient. Very similar to the cond case, now we're dealing with the cotangents. Hopefully this makes sense. We took the gradient of the body and we're running it N times. I forgot to draw N, too, but it's there.

Same scenario. Oh, no. What are we going to do? We can't just draw this edge between the two function definitions. So this time, we don't have to worry about the matching thing anymore. Thank goodness. We'll add the input to the grad body function and the grad cond function, but that's fine because we can ignore inputs. But we have a new problem, which is that there are actually multiple values of z. Because the body function is going to execute multiple times, there's no guarantee that this op that outputs z is going to output the same value on every iteration. So we actually have to output all the values of z from the forward pass, and we don't know how many that will be until we run it, and take them as input to the gradient function.
So we use stacks, otherwise known as accumulators in the code sometimes. So we're going to start with an empty-- we use tensor lists, which are kind of like tensor arrays, but not stateful. You can see in these little function signatures, we're going to start with an empty tensor list that we pass through the while. And then in the forward pass, we're going to push values onto that stack, or that list. And since it's stateless, you take the list in as input, and the value you want to add to it, and it conceptually returns you a new list that has that new element added to it. Under the hood it doesn't actually have to make all these copies, I hope. Similarly in the backwards pass. So then we're going to keep pushing values, outputting these new lists, and keep pushing to them until we get the full list with all the values in it. That's output from the while loop.

Actually, I have a picture for this. So I guess the point is that, in the backwards pass you just pop, the opposite of push, to get the value out again. And so, this is a little bit complicated. But you start with the empty list as input; now these lists are actually loop variables. So the stateless tensor list works quite nicely with this, because the loop variable is going to have whatever has accumulated so far as input to the body function, and it adds the new z and outputs that as the updated version. And so the final list is going to be the full list, which you pass into the gradient function. It's going to do the same thing, except popping instead of pushing, to get that raw value of z. And then finally, the list should be empty at the end. And then, since it's a loop variable, we end up outputting an empty list, but we don't actually need that output. That's just how it works.

SPEAKER 2: I have a question.

SKYE WANDERMAN-MILNE: Yeah.

SPEAKER 2: Are you saying the gradient values always [INAUDIBLE]?

SKYE WANDERMAN-MILNE: It's only when you need them.

SPEAKER 2: It's just always [INAUDIBLE]. OK. Thank you.

SKYE WANDERMAN-MILNE: Yeah. That's a good question.

SPEAKER 4: Now you could [INAUDIBLE] in the normal TensorFlow graph, which probably is able to remove them.

SKYE WANDERMAN-MILNE: Yeah, that's the way it actually used to be done. Although that's a little weird through functions, so we changed it.

SPEAKER 3: Another question. Does this imply that in your while loop, your memory consumption is basically linear in the number of iterations you go through?

SKYE WANDERMAN-MILNE: Yeah, if you have a gradient like this. That's some future work. I would love to see us doing re-materialization, or checkpointing, I think it's called in the literature. But we don't do that.

SPEAKER 2: Can you explain again, in the if, why can't you draw a line just from the original--

SKYE WANDERMAN-MILNE: Oh, yeah. The blue boxes are function definitions. And then the while op is going to call that function many times. So it's sort of like, if you're writing two functions in Python and they're not nested or anything, they're just side by side. You can't take an intermediate variable from one function and use it in another one. It's going to be like, I don't know what this is. You have to output it and then have it as input to the other function. Or at least in TensorFlow we don't have closures or anything fancy like that. So that's how we do it. Does that make sense?

SPEAKER 2: Kind of.
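To make the accumulator mechanics above a bit more concrete, here is a rough sketch of the stateless push/pop dataflow using the tensor-list ops exposed through tf.raw_ops. In the real gradient code these lists are threaded through the While op's loop variables rather than a Python loop, so this is only an illustration, not the actual generated graph:

```python
import tensorflow as tf

scalar_shape = tf.constant([], dtype=tf.int32)  # shape of each element (scalar)

# An empty, stateless tensor list: a variant-dtype tensor, not a resource.
acc = tf.raw_ops.EmptyTensorList(
    element_shape=scalar_shape,
    max_num_elements=-1,          # unbounded
    element_dtype=tf.float32)

# "Forward pass": each push returns a new list handle with the value appended.
for z in [1.0, 2.0, 3.0]:
  acc = tf.raw_ops.TensorListPushBack(input_handle=acc, tensor=z)

# "Backward pass": each pop returns a shorter list plus the raw value,
# in reverse order, leaving an empty list at the end.
acc, z_value = tf.raw_ops.TensorListPopBack(
    input_handle=acc, element_shape=scalar_shape, element_dtype=tf.float32)
print(z_value)  # 3.0
```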
SPEAKER 3: The value of a particular intermediate for a particular function execution doesn't have a name that can be addressed in order-- and if it had a name, it would greatly complicate the lifetime issues. We wouldn't be able to [INAUDIBLE] intermediate [INAUDIBLE] functions.

SKYE WANDERMAN-MILNE: Or maybe another way to put it is that these function definitions aren't actually in the graph. I draw them as if they are, but they're off to the side. All you see are the while ops. And then when you call the function, then you see it, but you only see it for that call. So this z op in here doesn't exist out here in the main graph where this gradient while op can see it, or in this other function definition.

Oh, and to compare to Control Flow v1 again: same general idea. These while ops would be the whole mess of low-level ops, with the two while loops represented that way. The big difference, this time, is in the stacks. They used the old resource-backed tensor arrays, which were stateful.

SPEAKER 4: We actually use the resource [INAUDIBLE] stack.

SKYE WANDERMAN-MILNE: Oh, you're right. You're right.

SPEAKER 4: Separate nests.

SKYE WANDERMAN-MILNE: OK, yeah. But they were stateful, is the point. So they were actually just inputs. They weren't outputs. And you just modify that state. One big disadvantage of this was that you couldn't take higher-order derivatives, because you would exhaust the stack once, and it's stateful and you can't get it back anymore. Whereas here, it's this full list. Because it's a stateless thing, I can pass it to another while op, no problem.

So coming back to Control Flow v2. Let's recap what's good and bad about it. So now we can take higher-order derivatives, because it's very simple. The gradient code, when it's looking at an If op, it doesn't know if that If op was actually the first derivative of some other If op. They're all the same. Inputs and outputs are just normal. It's much easier to convert to the XLA if and while ops, and for downstream TPU integration. Graph construction logic, I hope, is simpler. Take a look for yourself. So besides being easier to maintain, this lets us give better error messages, and hopefully there'll be fewer bugs. OK.

So now assuming that we just run the functional ops, even though I said we don't, assume we do. The execution could be much simpler, because we don't have dead tensors, since we use Optionals now, and we don't have frames, because that's managed by the while op. But the disadvantage of running these ops is that they aren't as performant, for a number of reasons listed there. So we could fix this for the functional ops. And it would make sense to do this, because a lot of these also apply to just regular function calls, which are kind of a big deal now. But for now, we decided to just take the functional op, so right before you run it-- so you've already constructed the graph, you're ready to run it-- we're going to convert it back into the old low-level representation. So now we get rid of the disadvantages, because we're, hopefully, just running the same thing. But we also don't get our simpler execution, because we're still running the old thing. We call this lowering, because we're sort of lowering it to this more low-level form. This was basically a staging trick so that we could do all the graph construction stuff, which was taking quite some time, without having to worry about the execution as much. Because there were still some issues.
It's very similar to function in-lining. An if op and a while op are kind of very fancy function calls, and so this is how you in-line them, with these low-level dataflow operators. And so it runs with the in-lining, before anything else happens, and this is so we can take advantage of any downstream optimization or placement or whatever. In the case of Control Flow, we want it to work the same as it did before in Control Flow v1. And I think Eugene is fixing this all up, so this is actually true now. As of, like, last week.

SPEAKER 5: So this converting will be removed eventually?

SKYE WANDERMAN-MILNE: I would love to see it removed. Oh, yeah. So right now we in-line everything, including function calls, because, similar story for functions, it makes a lot of things easier. I hope that we don't depend on this forever, that we do try to make it so function calls are just as performant and as good when not in-lined. Because it's the same for Control Flow. If we always assume everything's in-lined, then we're never going to be able to get our simpler execution and just run the functional ops. Because they're very, very similar to function calls, they have the same problems. So if you fix it for functions, it's not a huge step to then fix it for Control Flow.

Where are we at with Control Flow v2? It's still under development. There are bugs and features that need to be implemented. But it's basically on in tf 2.0, if you're using pure 2.0 code. So remember, Eager is doing its own thing, just using Python. And then Control Flow v2 is always on in tf dot functions. There's no way to get old Control Flow there. If you're using 1.x code or a compat dot v1 dot Graph, those still use the old Control Flow, but you can use this environment variable to turn the new one on. So now when people ping me and are like, I have this horrible Control Flow bug, I'm like, try the environment variable. And sometimes it fixes it. Or sometimes it at least gives an easier-to-debug error message.

Unfortunately, I would love to have realized the glorious future, where it's all new Control Flow, old Control Flow doesn't exist, and we can delete that code. I don't know if it makes sense to do the work to make it so we can turn it on in 1.x code, because there are a few big blockers. Namely, functions don't work with ref variables, and so by extension, these functional ops don't work with ref variables. That would be a lot of work to implement. And the thing you asked about, how we add the gradient outputs when you request a gradient: they're only added when they're needed, which we only know after we build the gradient graph and see what incoming edges we have. This actually breaks sessions. Sessions do not like it when you add inputs and outputs to ops, and it will potentially make your session unusable. You'll have to make a new session. So in 2.0 we don't have sessions, great. But in 1.x we definitely have sessions.

Another little note. In addition to Control Flow v2, there's a new effort to re-implement tensor arrays. And I sort of hinted at this by incorrectly describing the old tensor arrays as stacks, but it's the same idea. Tensor arrays were these resource-backed stateful things. Now we're re-implementing tensor arrays with the same API, so nothing should change for the user, but under the hood, we're going to use immutable tensor lists, which are variants instead of resources.
And so you get higher-order derivatives, and it's easier to reason about something that's dataflow-style instead of stateful in our dataflow graphs. It's nicer. And then in particular, an area of active development is that we do need to make these new tensor arrays work in XLA. So this is kind of annoying, because we kept saying, oh, the new Control Flow [INAUDIBLE], it's going to make XLA so easy, it's just going to work. But we do have to implement this one thing. [? Sarab's ?] working on this. I think it's almost there. We'll see.

SPEAKER 4: Getting there. Question. So is it true that TensorFlow [INAUDIBLE] where you only use the [INAUDIBLE]?

SKYE WANDERMAN-MILNE: Yes. Yeah, so it's maybe theoretically different from Control Flow, because it's tensor arrays. But tensor arrays are so tightly linked to Control Flow. And we only support the new tensor arrays in new Control Flow, because we don't want to deal with the stateful thing.

SPEAKER 2: For those who don't know what a tensor array is: usually when you do Control Flow in models, you have something like an RNN that computes something for [INAUDIBLE]. And you often want to take a single tensor that represents the results of all time steps together. And tensor array is the data structure that lets you do that.

SKYE WANDERMAN-MILNE: Yeah. I don't think there's too much use for tensor arrays outside of while loops, though I'm sure I would stand corrected if I looked into it. So these are some details on what's going on here.

That's all I have. I'm going to end on this slide so you can look at the beautiful picture. And I guess we have plenty of time for questions. So what was your Control Flow v1 question?

SPEAKER 3: How does it work with the branches [INAUDIBLE]?

SKYE WANDERMAN-MILNE: Oh, good question. So this is when you have a tf dot cond, which, remember, just takes lambdas and captures everything by closure. So you could just not close over anything. Like, return one or two.

SPEAKER 1: Or like, it's a sourceless op like [INAUDIBLE].

SKYE WANDERMAN-MILNE: Yeah. It uses the predicate. It wires together all the dataflow using the predicate. And in particular, you can also have a cond that doesn't return anything, it just has side effects. And I think in Control Flow v1, it will return the predicate value. I think it does that in Control Flow v2 too, because I wanted the tests to pass in both cases. But it's a little arbitrary.

SPEAKER 4: So the way to do this is you have ops that have a control dependency on something that depends on the Switch. Because [INAUDIBLE] propagates through [INAUDIBLE] as well. So this is how it's actually implemented in Control Flow v1.

SKYE WANDERMAN-MILNE: Yeah.

SPEAKER 1: Well, it can't depend on the Switch. It has to depend on, like, one output.

SPEAKER 4: Yeah. So you have a Switch of the predicate, and on each side of that [INAUDIBLE] that takes the predicate twice. Then you have an identity op on each branch. And every op that's inside one of the Switch branches has a control dependency on that corresponding identity. So because this then propagates through control edges, it makes things work.

SPEAKER 1: That makes sense.

SKYE WANDERMAN-MILNE: That's a part of why we were able to do [INAUDIBLE]. There's a lot of storage. Yeah?

SPEAKER 2: So when you described the graph modification for taking gradients of if, when does this modification happen? Does it happen when you construct the if op, or when you're taking gradients?

SKYE WANDERMAN-MILNE: Great question. It happens when you take the gradient.

SPEAKER 2: The gradient.
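As a small sketch of what that means from the user's side (the values here are made up, and this is not the internal rewriting itself): inside a tf.function, the gradient If op described earlier, with its extra intermediate outputs, only gets constructed when the tape is asked for the gradient, and because the result is again just an If op, a second tape can differentiate it once more for a higher-order derivative:

```python
import tensorflow as tf

@tf.function
def grads(x):
  with tf.GradientTape() as outer:
    outer.watch(x)
    with tf.GradientTape() as inner:
      inner.watch(x)
      y = tf.cond(x > 0.0, lambda: x * x * x, lambda: -x)
    dy = inner.gradient(y, x)    # gradient If op is built at this point
  d2y = outer.gradient(dy, x)    # higher-order derivative: yet another If op
  return dy, d2y

print(grads(tf.constant(2.0)))   # (12.0, 12.0) for the x**3 branch
```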
So for those--

SPEAKER 3: Does that depend on whether you're using tape gradients or tf dot gradients?

SKYE WANDERMAN-MILNE: No.

SPEAKER 2: We could [INAUDIBLE] early if you're doing tape gradients. We currently do not.

SKYE WANDERMAN-MILNE: Yeah.

SPEAKER 4: So that means for those function arguments, or function attributes, you cannot draw lines between two of them, but you can modify one.

SKYE WANDERMAN-MILNE: Yeah, you can modify them to add inputs and outputs, which you're not really supposed to do with sessions. But we do it. The reason we do it when you request a gradient is that, a, if you never take the gradient, we don't want to add extra stuff, although it could get pruned.

SPEAKER 4: You want to look [INAUDIBLE].

SKYE WANDERMAN-MILNE: It makes your graph look nicer, at least, to not have all the extra outputs. And also, you don't know which intermediates you're going to need until you build the gradient graph. So if we did it with the tape, we could say, oh, presumably because you're running with a tape, you are going to want to take the gradient at some point.

SPEAKER 4: We can actually ask the tape if the tape is going to differentiate one of those outputs. We can't answer their questions.

SKYE WANDERMAN-MILNE: So then we could proactively create the gradient at the same time as you create the forward pass, and add the outputs there, all at once. But since we have the two code paths, we just do it the same way in both code paths. Because with tf dot gradients, you have no idea if you're going to call it or not until it happens. That's a good question. Functions work the same way too, because they have, like, a similar-- if you just have a function call, you'll have the same thing with intermediates, and you'll have to add inputs and outputs.

So we're back in Control Flow v1, right? This is what it looks like, this stuff. What if you want to run your branch functions or your body or whatever on multiple devices? So I don't totally understand this myself. It's going to be brief. Cond, it's pretty simple. You just do it like normal, I guess. You add the sends and receives, and dead tensors can flow through these. So this is why you need the dead tensors. Because for the untaken branch, you basically need to tell the other device, this isn't taken, stop waiting for inputs on this, so you can shut down or whatever.

SPEAKER 4: Alternatively, we could have chosen to send the predicate instead. But this was a simple modification of the existing TensorFlow that had a huge cost. If I had chosen to send the predicate, we wouldn't need so much of that dead tensor propagation and all the bugs associated with it.

SKYE WANDERMAN-MILNE: Dead tensors are kind of crazy. In really big graphs, you will spend time just propagating all the dead tensors, and sending data across the network, or whatever. It's one of those things. We added all this stuff and now this is very conceptually simple. You just add the send and receive. It just works. Can we do the same thing for while loops? Just add the sends and receives. This time it's going to be in a loop. Seems fine. It's not fine. The problem is that this device doesn't know that this op is supposed to be run multiple times. I guess we didn't forward the frame information.

SPEAKER 3: It doesn't know how many times it should run.

SKYE WANDERMAN-MILNE: Well, it's going to run once, or like 0 times, then you'll have-- or maybe the dead tensor will work. But if you run it once, it's just going to immediately shut down, because it thinks that it has to run once, like a regular op.
So the solution is, you basically build a tiny little while loop on the other device. And so you can see there's no real data going through this computation. It's just used, through carefully placed control dependencies, to drive this op as many times as you need. So this is like a whole little while loop built just to run this op N times. This while loop is indirectly driven by the real one.

SPEAKER 3: It's driven by the predicate.

SKYE WANDERMAN-MILNE: Yeah. Right, exactly. You can see that this guy does not have a predicate.

SPEAKER 4: So we're essentially sending the predicate around for the while loop case, but not doing it for the cond case.

SKYE WANDERMAN-MILNE: And we build a little tiny while loop to actually use that predicate.

SPEAKER 4: And essentially, if we wanted to partition into two ops, we would have to build something like this for both the cond and [INAUDIBLE]. Or it would at least look simpler, I think.

SPEAKER 1: Well, the control could be centralized.

SPEAKER 4: Well, you could send the predicate to other places, yes.

SPEAKER 1: [INAUDIBLE] execution, yeah.

SKYE WANDERMAN-MILNE: Yeah.

SPEAKER 4: You would need a while loop [INAUDIBLE] device, but the predicate computation only needs to happen once.

SKYE WANDERMAN-MILNE: Do we? Because we have multi-device functions, you could just call that multiple times, right?

SPEAKER 4: Yeah. I mean, sure.

SKYE WANDERMAN-MILNE: You won't get, like, parallel iterations and everything. So that's distribution.

SPEAKER 6: I'm glad you speak clearly.

SPEAKER 3: How did the intermediate value sharing work with distribution [INAUDIBLE]?

SPEAKER 1: It works the same way, except there's a lot more arrows. [LAUGHTER] Conceptually, they do not interfere with [INAUDIBLE]. But you end up with-- the diagram to show both at the same time would be overwhelming.

SKYE WANDERMAN-MILNE: Yeah, that's a good point, though. I feel like it's not immediately obvious that it works with all the dead tensors and stuff between the forward and backwards pass, because now you're, like, mixing [INAUDIBLE]. But it does somehow work.

SPEAKER 4: You need to think of the intermediates as happening before you do the partitioning, and then you can see what should happen.

SKYE WANDERMAN-MILNE: I'll go back to my pretty picture. Well, thanks, everyone.

SPEAKER 6: Thank you.

[APPLAUSE]