diff --git a/src/QsCompiler/Core/ConstructorExtensions.fs b/src/QsCompiler/Core/ConstructorExtensions.fs index 2c3a8226ec..45263ccc5a 100644 --- a/src/QsCompiler/Core/ConstructorExtensions.fs +++ b/src/QsCompiler/Core/ConstructorExtensions.fs @@ -11,39 +11,39 @@ open Microsoft.Quantum.QsCompiler.SyntaxTokens open Microsoft.Quantum.QsCompiler.SyntaxTree -type QsQualifiedName with +type QsQualifiedName with static member New (nsName, cName) = { Namespace = nsName Name = cName } -type UserDefinedType with +type UserDefinedType with static member New (nsName, tName, range) = { Namespace = nsName Name = tName Range = range } -type QsTypeParameter with +type QsTypeParameter with static member New (origin, tName, range) = { Origin = origin TypeName = tName Range = range } -type QsLocation with +type QsLocation with static member New (pos, range) = { Offset = pos Range = range } -type InferredExpressionInformation with +type InferredExpressionInformation with static member New (isMutable, quantumDep) = { - IsMutable = isMutable + IsMutable = isMutable HasLocalQuantumDependency = quantumDep } -type LocalVariableDeclaration<'Name> with +type LocalVariableDeclaration<'Name> with static member New isMutable ((pos, range), vName : 'Name, t, hasLocalQuantumDependency) = { VariableName = vName Type = t @@ -52,63 +52,63 @@ type LocalVariableDeclaration<'Name> with Range = range } -type LocalDeclarations with +type LocalDeclarations with static member New (variables : IEnumerable<_>) = { Variables = variables.ToImmutableArray() } - static member Concat this other = + static member Concat this other = LocalDeclarations.New (this.Variables.Concat other.Variables) - member this.AsVariableLookup () = + member this.AsVariableLookup () = let localVars = this.Variables |> Seq.map (fun decl -> decl.VariableName, decl) new ReadOnlyDictionary<_,_>(localVars.ToDictionary(fst, snd)) -type InferredCallableInformation with +type InferredCallableInformation with /// the default values are intrinsic: false, selfAdj: false static member New (?intrinsic, ?selfAdj) = { IsIntrinsic = defaultArg intrinsic false IsSelfAdjoint = defaultArg selfAdj false } -type CallableInformation with +type CallableInformation with static member New (characteristics, inferredInfo) = { Characteristics = characteristics InferredInformation = inferredInfo } -type TypedExpression with +type TypedExpression with /// Builds and returns a TypedExpression with the given properties. /// The UnresolvedType of the given expression is set to the given expression type, and /// the ResolvedType is set to the type constructed by resolving it using ResolveTypeParameters and the given look-up. static member New (expr, typeParamResolutions : ImmutableDictionary<_,_>, exType, exInfo, range) = { Expression = expr - TypeArguments = typeParamResolutions |> Seq.map (fun kv -> fst kv.Key, snd kv.Key, kv.Value) |> ImmutableArray.CreateRange + TypeArguments = TypedExpression.AsTypeArguments typeParamResolutions ResolvedType = ResolvedType.ResolveTypeParameters typeParamResolutions exType InferredInformation = exInfo Range = range } - -type QsBinding<'T> with + +type QsBinding<'T> with static member New kind (lhs, rhs) = { Kind = kind Lhs = lhs Rhs = rhs } -type QsValueUpdate with +type QsValueUpdate with static member New (lhs, rhs) = { Lhs = lhs Rhs = rhs } -type QsComments with +type QsComments with static member New (before : IEnumerable<_>, after : IEnumerable<_>) = { OpeningComments = before.ToImmutableArray() ClosingComments = after.ToImmutableArray() } -type QsScope with +type QsScope with static member New (statements : IEnumerable<_>, parentSymbols) = { Statements = statements.ToImmutableArray() KnownSymbols = parentSymbols @@ -121,7 +121,7 @@ type QsPositionedBlock with Comments = comments } -type QsConditionalStatement with +type QsConditionalStatement with static member New (blocks : IEnumerable<_>, defaultBlock) = { ConditionalBlocks = blocks.ToImmutableArray() Default = defaultBlock @@ -134,33 +134,33 @@ type QsForStatement with Body = body } -type QsWhileStatement with +type QsWhileStatement with static member New (condition, body) = { Condition = condition Body = body } -type QsRepeatStatement with +type QsRepeatStatement with static member New (repeatBlock, successCondition, fixupBlock) = { RepeatBlock = repeatBlock SuccessCondition = successCondition FixupBlock = fixupBlock } -type QsConjugation with +type QsConjugation with static member New (outer, inner) = { OuterTransformation = outer InnerTransformation = inner } -type QsQubitScope with +type QsQubitScope with static member New kind ((lhs,rhs), body) = { Kind = kind Binding = QsBinding.New QsBindingKind.ImmutableBinding (lhs, rhs) Body = body } -type QsStatement with +type QsStatement with static member New comments location (kind, symbolDecl) = { Statement = kind SymbolDeclarations = symbolDecl @@ -168,7 +168,7 @@ type QsStatement with Comments = comments } -type ResolvedSignature with +type ResolvedSignature with static member New ((argType, returnType), info, typeParams : IEnumerable<_>) = { TypeParameters = typeParams.ToImmutableArray() ArgumentType = argType @@ -176,7 +176,7 @@ type ResolvedSignature with Information = info } -type QsSpecialization with +type QsSpecialization with static member New kind (source, location) (parent, attributes, typeArgs, signature, implementation, documentation, comments) = { Kind = kind Parent = parent @@ -194,7 +194,7 @@ type QsSpecialization with static member NewControlled = QsSpecialization.New QsControlled static member NewControlledAdjoint = QsSpecialization.New QsControlledAdjoint -type QsCallable with +type QsCallable with static member New kind (source, location) (name, attributes, modifiers, argTuple, signature, specializations : IEnumerable<_>, documentation, comments) = { Kind = kind FullName = name @@ -225,7 +225,7 @@ type QsCustomType with Comments = comments } -type QsDeclarationAttribute with +type QsDeclarationAttribute with static member New (typeId, arg, pos, comments) = { TypeId = typeId Argument = arg @@ -234,18 +234,18 @@ type QsDeclarationAttribute with } type QsNamespaceElement with - static member NewOperation loc = QsCallable.NewOperation loc >> QsCallable + static member NewOperation loc = QsCallable.NewOperation loc >> QsCallable static member NewFunction loc = QsCallable.NewFunction loc >> QsCallable static member NewType loc = QsCustomType.New loc >> QsCustomType -type QsNamespace with +type QsNamespace with static member New (name, elements : IEnumerable<_>, documentation) = { Name = name Elements = elements.ToImmutableArray() Documentation = documentation } -type QsCompilation with +type QsCompilation with static member New (namespaces, entryPoints) = { Namespaces = namespaces EntryPoints = entryPoints diff --git a/src/QsCompiler/DataStructures/SyntaxTree.fs b/src/QsCompiler/DataStructures/SyntaxTree.fs index f41b113b48..e18926fba3 100644 --- a/src/QsCompiler/DataStructures/SyntaxTree.fs +++ b/src/QsCompiler/DataStructures/SyntaxTree.fs @@ -67,8 +67,8 @@ type QsQualifiedName = { /// the declared name of the namespace element Name : NonNullable } - with - override this.ToString () = + with + override this.ToString () = sprintf "%s.%s" this.Namespace.Value this.Name.Value @@ -344,8 +344,8 @@ type InferredExpressionInformation = { type TypedExpression = { /// the content (kind) of the expression Expression : QsExpressionKind - /// contains all type arguments implicitly or explicitly determined by the expression, - /// i.e. the origin, name and concrete type of all type parameters whose type can either be inferred based on the expression, + /// contains all type arguments implicitly or explicitly determined by the expression, + /// i.e. the origin, name and concrete type of all type parameters whose type can either be inferred based on the expression, /// or who have explicitly been resolved by provided type arguments TypeArguments : ImmutableArray * ResolvedType> /// the type of the expression after applying the type arguments @@ -361,13 +361,18 @@ type TypedExpression = { /// Contains a dictionary mapping the origin and name of all type parameters whose type can either be inferred based on the expression, /// or who have explicitly been resolved by provided type arguments to their concrete type within this expression - member this.TypeParameterResolutions = + member this.TypeParameterResolutions = this.TypeArguments.ToImmutableDictionary((fun (origin, name, _) -> origin, name), (fun (_,_,t) -> t)) + /// Given a dictionary containing the type resolutions for an expression, + /// returns the corresponding ImmutableArray to initialize the TypeArguments with. + static member AsTypeArguments (typeParamResolutions : ImmutableDictionary<_,_>) = + typeParamResolutions |> Seq.map (fun kv -> fst kv.Key, snd kv.Key, kv.Value) |> ImmutableArray.CreateRange + /// Returns true if the expression is a call-like expression, and the arguments contain a missing expression. /// Returns false otherwise. static member public IsPartialApplication kind = - let rec containsMissing ex = + let rec containsMissing ex = match ex.Expression with | MissingExpr -> true | ValueTuple items -> items |> Seq.exists containsMissing @@ -564,7 +569,7 @@ and QsStatementKind = | QsRepeatStatement of QsRepeatStatement | QsConjugation of QsConjugation | QsQubitScope of QsQubitScope // includes both using and borrowing scopes -| EmptyStatement +| EmptyStatement and QsStatement = { @@ -645,7 +650,7 @@ type QsSpecialization = { /// Contains the location information for the declared specialization. /// The position offset represents the position in the source file where the specialization is declared, /// and the range contains the range of the corresponding specialization header. - /// For auto-generated specializations, the location is set to the location of the parent callable declaration. + /// For auto-generated specializations, the location is set to the location of the parent callable declaration. Location : QsNullable /// contains the type arguments for which the implementation is specialized TypeArguments : QsNullable> diff --git a/src/QsCompiler/Tests.Compiler/TestCases/TypeParameter.qs b/src/QsCompiler/Tests.Compiler/TestCases/TypeParameter.qs new file mode 100644 index 0000000000..c91be7ec96 --- /dev/null +++ b/src/QsCompiler/Tests.Compiler/TestCases/TypeParameter.qs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + + +// Identifier Resolution +namespace Microsoft.Quantum.Testing.TypeParameter { + + operation Main() : Unit { + Foo<_, Int, String>(1.0, 2, "Three"); + } + + operation Foo<'A, 'B, 'C>(a : 'A, b : 'B, c : 'C) : Unit { } +} + +// ================================= + +// Adjoint Application Resolution +namespace Microsoft.Quantum.Testing.TypeParameter { + + operation Main() : Unit { + (Adjoint (Foo(1.0, 2, _)))("Three"); + } + + operation Foo<'A, 'B, 'C>(a : 'A, b : 'B, c : 'C) : Unit is Adj { } +} + +// ================================= + +// Controlled Application Resolution +namespace Microsoft.Quantum.Testing.TypeParameter { + + operation Main(qs : Qubit) : Unit { + (Controlled (Foo(1.0, 2, _)))([qs], "Three"); + } + + operation Foo<'A, 'B, 'C>(a : 'A, b : 'B, c : 'C) : Unit is Ctl { } +} + +// ================================= + +// Partial Application Resolution +namespace Microsoft.Quantum.Testing.TypeParameter { + + operation Main() : Unit { + (Foo(_, 3, "Hi"))(1.0); + } + + operation Foo<'A, 'B, 'C>(a : 'A, b : 'B, c : 'C) : Unit { } +} + +// ================================= + +// Sub-call Resolution +namespace Microsoft.Quantum.Testing.TypeParameter { + + operation Main() : Unit { + (Foo(1, 2, 3))(4); + } + + operation Foo<'A, 'B, 'C>(a : 'A, b : 'B, c : 'C) : ('C => Unit) { return Bar<'A, 'B, 'C>(a, b, _); } + + operation Bar<'A, 'B, 'C>(a : 'A, b : 'B, c : 'C) : Unit { } +} + +// ================================= + +// Argument Sub-call Resolution +namespace Microsoft.Quantum.Testing.TypeParameter { + + operation Main() : Unit { + Foo(1.0, Bar(2), "Three"); + } + + operation Foo<'A, 'B, 'C>(a : 'A, b : 'B, c : 'C) : Unit { } + + operation Bar<'A>(a : 'A) : 'A { return a; } +} diff --git a/src/QsCompiler/Tests.Compiler/Tests.Compiler.fsproj b/src/QsCompiler/Tests.Compiler/Tests.Compiler.fsproj index 03fc1fbdff..e8234d9639 100644 --- a/src/QsCompiler/Tests.Compiler/Tests.Compiler.fsproj +++ b/src/QsCompiler/Tests.Compiler/Tests.Compiler.fsproj @@ -145,6 +145,9 @@ PreserveNewest + + PreserveNewest + @@ -158,6 +161,7 @@ + diff --git a/src/QsCompiler/Tests.Compiler/TypeParameterTests.fs b/src/QsCompiler/Tests.Compiler/TypeParameterTests.fs new file mode 100644 index 0000000000..97ac80ea44 --- /dev/null +++ b/src/QsCompiler/Tests.Compiler/TypeParameterTests.fs @@ -0,0 +1,621 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Microsoft.Quantum.QsCompiler.Testing + +open System +open System.Collections.Immutable +open System.IO +open Microsoft.Quantum.QsCompiler +open Microsoft.Quantum.QsCompiler.CompilationBuilder +open Microsoft.Quantum.QsCompiler.DataTypes +open Microsoft.Quantum.QsCompiler.SyntaxExtensions +open Microsoft.Quantum.QsCompiler.SyntaxTokens +open Microsoft.Quantum.QsCompiler.SyntaxTree +open Xunit + + +type TypeParameterTests () = + + let TypeParameterNS = "Microsoft.Quantum.Testing.TypeParameter" + + let qualifiedName name = + (TypeParameterNS |> NonNullable.New, name |> NonNullable.New) |> QsQualifiedName.New + + let typeParameter (id : string) = + let pieces = id.Split(".") + Assert.True(pieces.Length = 2) + let parent = qualifiedName pieces.[0] + let name = pieces.[1] |> NonNullable.New + QsTypeParameter.New (parent, name, Null) + + let FooA = typeParameter "Foo.A" + let FooB = typeParameter "Foo.B" + let FooC = typeParameter "Foo.C" + let BarA = typeParameter "Bar.A" + let BarB = typeParameter "Bar.B" + let BazA = typeParameter "Baz.A" + + let MakeTupleType types = + types |> Seq.map ResolvedType.New |> ImmutableArray.CreateRange |> TupleType + + let ResolutionFromParam (res : (QsTypeParameter * QsTypeKind<_,_,_,_>) list) = + res.ToImmutableDictionary((fun (tp,_) -> tp.Origin, tp.TypeName), snd >> ResolvedType.New) + + let CheckResolutionMatch (res1 : ImmutableDictionary<_,_>) (res2 : ImmutableDictionary<_,_> ) = + let keysMismatch = ImmutableHashSet.CreateRange(res1.Keys).SymmetricExcept res2.Keys + keysMismatch.Count = 0 && res1 |> Seq.exists (fun kv -> res2.[kv.Key] <> kv.Value) |> not + + let AssertExpectedResolution expected given = + Assert.True(CheckResolutionMatch expected given, "Given resolutions did not match the expected resolutions.") + + let CheckCombinedResolution expected (resolutions : ImmutableDictionary<(QsQualifiedName*NonNullable),ResolvedType> []) = + let combination = TypeResolutionCombination(resolutions) + AssertExpectedResolution expected combination.CombinedResolutionDictionary + combination.IsValid + + let AssertCombinedResolution expected resolutions = + let success = CheckCombinedResolution expected resolutions + Assert.True(success, "Combining type resolutions was not successful.") + + let AssertCombinedResolutionFailure expected resolutions = + let success = CheckCombinedResolution expected resolutions + Assert.False(success, "Combining type resolutions should have failed.") + + let compilationManager = new CompilationUnitManager(new Action (fun ex -> failwith ex.Message)) + + let getTempFile () = new Uri(Path.GetFullPath(Path.GetRandomFileName())) + let getManager uri content = CompilationUnitManager.InitializeFileManager(uri, content, compilationManager.PublishDiagnostics, compilationManager.LogException) + + let ReadAndChunkSourceFile fileName = + let sourceInput = Path.Combine ("TestCases", fileName) |> File.ReadAllText + sourceInput.Split ([|"==="|], StringSplitOptions.RemoveEmptyEntries) + + let BuildContent content = + + let fileId = getTempFile() + let file = getManager fileId content + + compilationManager.AddOrUpdateSourceFileAsync(file) |> ignore + let compilationDataStructures = compilationManager.Build() + compilationManager.TryRemoveSourceFileAsync(fileId, false) |> ignore + + compilationDataStructures.Diagnostics() |> Seq.exists (fun d -> d.IsError()) |> Assert.False + Assert.NotNull compilationDataStructures.BuiltCompilation + + compilationDataStructures + + let CompileTypeParameterTest testNumber = + let srcChunks = ReadAndChunkSourceFile "TypeParameter.qs" + srcChunks.Length >= testNumber |> Assert.True + let compilationDataStructures = BuildContent <| srcChunks.[testNumber-1] + let processedCompilation = compilationDataStructures.BuiltCompilation + Assert.NotNull processedCompilation + processedCompilation + + let GetCallableWithName compilation ns name = + compilation.Namespaces + |> Seq.filter (fun x -> x.Name.Value = ns) + |> GlobalCallableResolutions + |> Seq.find (fun x -> x.Key.Name.Value = name) + |> (fun x -> x.Value) + + let GetMainExpression (compilation : QsCompilation) = + let mainCallable = GetCallableWithName compilation TypeParameterNS "Main" + let body = + mainCallable.Specializations + |> Seq.find (fun x -> x.Kind = QsSpecializationKind.QsBody) + |> fun x -> match x.Implementation with + | Provided (_, body) -> body + | _ -> failwith "Expected but did not find Provided Implementation" + Assert.True(body.Statements.Length = 1) + match body.Statements.[0].Statement with + | QsExpressionStatement expression -> expression + | _ -> failwith "Expected but did not find an Expression Statement" + + + [] + [] + member this.``Resolution to Concrete`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + (FooB, Int) + ] + ResolutionFromParam [ + (BarA, Int) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Int) + (FooB, Int) + (BarA, Int) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Resolution to Type Parameter`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, BazA |> TypeParameter) + ] + |] + let expected = ResolutionFromParam [ + (FooA, BazA |> TypeParameter) + (BarA, BazA |> TypeParameter) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Resolution via Identity Mapping`` () = + + let given = [| + ResolutionFromParam [ + (FooA, FooA |> TypeParameter) + ] + ResolutionFromParam [ + (FooA, Int) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Int) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Multi-Stage Resolution`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, Int) + (FooB, Int) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Int) + (FooB, Int) + (BarA, Int) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Multiple Resolutions to Concrete`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + (FooB, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, Int) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Int) + (FooB, Int) + (BarA, Int) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Multiple Resolutions to Type Parameter`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + (FooB, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, BazA |> TypeParameter) + ] + ResolutionFromParam [ + (BazA, Double) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Double) + (FooB, Double) + (BarA, Double) + (BazA, Double) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Multi-Stage Resolution of Multiple Resolutions to Concrete`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (FooB, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, Int) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Int) + (FooB, Int) + (BarA, Int) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Multi-Stage Resolution of Multiple Resolutions to Type Parameter`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, BazA |> TypeParameter) + (FooB, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BazA, Int) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Int) + (FooB, Int) + (BarA, Int) + (BazA, Int) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Redundant Resolution to Concrete`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, Int) + (FooA, Int) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Int) + (BarA, Int) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Redundant Resolution to Type Parameter`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, BazA |> TypeParameter) + ] + ResolutionFromParam [ + (FooA, BazA |> TypeParameter) + ] + |] + let expected = ResolutionFromParam [ + (FooA, BazA |> TypeParameter) + (BarA, BazA |> TypeParameter) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Conflicting Resolution to Concrete`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, Int) + (FooA, Double) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Double) + (BarA, Int) + ] + + AssertCombinedResolutionFailure expected given + + [] + [] + member this.``Conflicting Resolution to Type Parameter`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, BazA |> TypeParameter) + (FooA, BarB |> TypeParameter) + ] + |] + let expected = ResolutionFromParam [ + (FooA, BarB |> TypeParameter) + (BarA, BazA |> TypeParameter) + ] + + AssertCombinedResolutionFailure expected given + + [] + [] + member this.``Direct Resolution to Native`` () = + + let given = [| + ResolutionFromParam [ + (FooA, FooA |> TypeParameter) + ] + |] + let expected = ResolutionFromParam [ + (FooA, FooA |> TypeParameter) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Indirect Resolution to Native`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, BazA |> TypeParameter) + ] + ResolutionFromParam [ + (BazA, FooA |> TypeParameter) + ] + |] + let expected = ResolutionFromParam [ + (FooA, FooA |> TypeParameter) + (BarA, FooA |> TypeParameter) + (BazA, FooA |> TypeParameter) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Direct Resolution Constrains Native`` () = + + let given = [| + ResolutionFromParam [ + (FooA, FooB |> TypeParameter) + ] + |] + let expected = ResolutionFromParam [ + (FooA, FooB |> TypeParameter) + ] + + AssertCombinedResolutionFailure expected given + + [] + [] + member this.``Indirect Resolution Constrains Native`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, BazA |> TypeParameter) + ] + ResolutionFromParam [ + (BazA, FooB |> TypeParameter) + ] + |] + let expected = ResolutionFromParam [ + (FooA, FooB |> TypeParameter) + (BarA, FooB |> TypeParameter) + (BazA, FooB |> TypeParameter) + ] + + AssertCombinedResolutionFailure expected given + + [] + [] + member this.``Inner Cycle Constrains Type Parameter`` () = + + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + ] + ResolutionFromParam [ + (BarA, BazA |> TypeParameter) + ] + ResolutionFromParam [ + (BazA, BarB |> TypeParameter) + ] + |] + let expected = ResolutionFromParam [ + (FooA, BarB |> TypeParameter) + (BarA, BarB |> TypeParameter) + (BazA, BarB |> TypeParameter) + ] + + AssertCombinedResolutionFailure expected given + + [] + [] + member this.``Nested Type Paramter Resolution`` () = + let given = [| + ResolutionFromParam [ + (FooA, [BarA |> TypeParameter; Int] |> MakeTupleType) + ] + ResolutionFromParam [ + (BarA, [String; Double] |> MakeTupleType) + ] + |] + let expected = ResolutionFromParam [ + (FooA, [[String; Double] |> MakeTupleType; Int] |> MakeTupleType) + (BarA, [String; Double] |> MakeTupleType) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Nested Constricted Resolution`` () = + let given = [| + ResolutionFromParam [ + (FooA, [FooB |> TypeParameter; Int] |> MakeTupleType) + ] + |] + let expected = ResolutionFromParam [ + (FooA, [FooB |> TypeParameter; Int] |> MakeTupleType) + ] + + AssertCombinedResolutionFailure expected given + + [] + [] + member this.``Nested Self Resolution`` () = + let given = [| + ResolutionFromParam [ + (FooA, [FooA |> TypeParameter; BarA |> TypeParameter] |> MakeTupleType) + ] + ResolutionFromParam [ + (BarA, Int) + ] + |] + let expected = ResolutionFromParam [ + (FooA, [FooA |> TypeParameter; Int] |> MakeTupleType) + (BarA, Int) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Single Dictonary Resolution`` () = + let given = [| + ResolutionFromParam [ + (FooA, BarA |> TypeParameter) + (BarA, Int) + ] + |] + let expected = ResolutionFromParam [ + (FooA, Int) + (BarA, Int) + ] + + AssertCombinedResolution expected given + + [] + [] + member this.``Identifier Resolution`` () = + let expression = CompileTypeParameterTest 1 |> GetMainExpression + + let combination = TypeResolutionCombination(expression) + let given = combination.CombinedResolutionDictionary + let expected = ResolutionFromParam [ + (FooA, Double) + (FooB, Int) + (FooC, String) + ] + + AssertExpectedResolution expected given + + [] + [] + member this.``Adjoint Application Resolution`` () = + let expression = CompileTypeParameterTest 2 |> GetMainExpression + + let combination = TypeResolutionCombination(expression) + let given = combination.CombinedResolutionDictionary + let expected = ResolutionFromParam [ + (FooA, Double) + (FooB, Int) + (FooC, String) + ] + + AssertExpectedResolution expected given + + [] + [] + member this.``Controlled Application Resolution`` () = + let expression = CompileTypeParameterTest 3 |> GetMainExpression + + let combination = TypeResolutionCombination(expression) + let given = combination.CombinedResolutionDictionary + let expected = ResolutionFromParam [ + (FooA, Double) + (FooB, Int) + (FooC, String) + ] + + AssertExpectedResolution expected given + + [] + [] + member this.``Partial Application Resolution`` () = + let expression = CompileTypeParameterTest 4 |> GetMainExpression + + let combination = TypeResolutionCombination(expression) + let given = combination.CombinedResolutionDictionary + let expected = ResolutionFromParam [ + (FooA, Double) + (FooB, Int) + (FooC, String) + ] + + AssertExpectedResolution expected given + + [] + [] + member this.``Sub-call Resolution`` () = + let expression = CompileTypeParameterTest 5 |> GetMainExpression + + let combination = TypeResolutionCombination(expression) + let given = combination.CombinedResolutionDictionary + let expected = ResolutionFromParam [ ] + + AssertExpectedResolution expected given + + [] + [] + member this.``Argument Sub-call Resolution`` () = + let expression = CompileTypeParameterTest 6 |> GetMainExpression + + let combination = TypeResolutionCombination(expression) + let given = combination.CombinedResolutionDictionary + let expected = ResolutionFromParam [ + (FooA, Double) + (FooB, Int) + (FooC, String) + ] + + AssertExpectedResolution expected given diff --git a/src/QsCompiler/Transformations/ClassicallyControlled.cs b/src/QsCompiler/Transformations/ClassicallyControlled.cs index d7018f6f93..ff15c9740e 100644 --- a/src/QsCompiler/Transformations/ClassicallyControlled.cs +++ b/src/QsCompiler/Transformations/ClassicallyControlled.cs @@ -249,29 +249,6 @@ public StatementTransformation(SyntaxTreeTransformation par { } - /// - /// Get the combined type resolutions for a pair of nested resolutions, - /// resolving references in the inner resolutions to the outer resolutions. - /// - private TypeArgsResolution GetCombinedTypeResolution(TypeArgsResolution outer, TypeArgsResolution inner) - { - var outerDict = outer.ToDictionary(x => (x.Item1, x.Item2), x => x.Item3); - return inner.Select(innerRes => - { - if (innerRes.Item3.Resolution is ResolvedTypeKind.TypeParameter typeParam && - outerDict.TryGetValue((typeParam.Item.Origin, typeParam.Item.TypeName), out var outerRes)) - { - outerDict.Remove((typeParam.Item.Origin, typeParam.Item.TypeName)); - return Tuple.Create(innerRes.Item1, innerRes.Item2, outerRes); - } - else - { - return innerRes; - } - }) - .Concat(outerDict.Select(x => Tuple.Create(x.Key.Item1, x.Key.Item2, x.Value))).ToImmutableArray(); - } - /// /// Checks if the scope is valid for conversion to an operation call from the conditional control API. /// It is valid if there is exactly one statement in it and that statement is a call like expression statement. @@ -289,25 +266,26 @@ private TypeArgsResolution GetCombinedTypeResolution(TypeArgsResolution outer, T && !TypedExpression.IsPartialApplication(expr.Item.Expression) && call.Item1.Expression is ExpressionKind.Identifier) { - // We are dissolving the application of arguments here, so the call's type argument - // resolutions have to be moved to the 'identifier' sub expression. - - var callTypeArguments = expr.Item.TypeArguments; - var idTypeArguments = call.Item1.TypeArguments; - var combinedTypeArguments = this.GetCombinedTypeResolution(callTypeArguments, idTypeArguments); - - // This relies on global callables being the only things that have type parameters. var newCallIdentifier = call.Item1; - if (combinedTypeArguments.Any() - && newCallIdentifier.Expression is ExpressionKind.Identifier id - && id.Item1 is Identifier.GlobalCallable global) + var callTypeArguments = expr.Item.TypeParameterResolutions; + + // This relies on anything having type parameters must be a global callable. + if (newCallIdentifier.Expression is ExpressionKind.Identifier id + && id.Item1 is Identifier.GlobalCallable global + && callTypeArguments.Any()) { + // We are dissolving the application of arguments here, so the call's type argument + // resolutions have to be moved to the 'identifier' sub expression. + var combination = new TypeResolutionCombination(expr.Item); + var combinedTypeArguments = combination.CombinedResolutionDictionary.Where(kvp => kvp.Key.Item1.Equals(global.Item)).ToImmutableDictionary(); + QsCompilerError.Verify(combination.IsValid, "failed to combine type parameter resolution"); + var globalCallable = this.SharedState.Compilation.Namespaces .Where(ns => ns.Name.Equals(global.Item.Namespace)) .Callables() .FirstOrDefault(c => c.FullName.Name.Equals(global.Item.Name)); - QsCompilerError.Verify(globalCallable != null, $"Could not find the global reference {global.Item.Namespace.Value + "." + global.Item.Name.Value}"); + QsCompilerError.Verify(globalCallable != null, $"Could not find the global reference {global.Item}."); var callableTypeParameters = globalCallable.Signature.TypeParameters .Select(x => x as QsLocalSymbol.ValidName); @@ -319,8 +297,8 @@ private TypeArgsResolution GetCombinedTypeResolution(TypeArgsResolution outer, T id.Item1, QsNullable>.NewValue( callableTypeParameters - .Select(x => combinedTypeArguments.First(y => y.Item2.Equals(x.Item)).Item3).ToImmutableArray())), - combinedTypeArguments, + .Select(x => combinedTypeArguments[Tuple.Create(global.Item, x.Item)]).ToImmutableArray())), + TypedExpression.AsTypeArguments(combinedTypeArguments), call.Item1.ResolvedType, call.Item1.InferredInformation, call.Item1.Range); diff --git a/src/QsCompiler/Transformations/TypeResolutionCombination.cs b/src/QsCompiler/Transformations/TypeResolutionCombination.cs new file mode 100644 index 0000000000..88f210dfe5 --- /dev/null +++ b/src/QsCompiler/Transformations/TypeResolutionCombination.cs @@ -0,0 +1,397 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.Quantum.QsCompiler.DataTypes; +using Microsoft.Quantum.QsCompiler.SyntaxTokens; +using Microsoft.Quantum.QsCompiler.SyntaxTree; +using Microsoft.Quantum.QsCompiler.Transformations.Core; + +#nullable enable + +namespace Microsoft.Quantum.QsCompiler +{ + using ExpressionKind = QsExpressionKind; + using ResolvedTypeKind = QsTypeKind; + // Type Parameters are frequently referenced by the callable of the type parameter followed by the name of the specific type parameter. + using TypeParameterName = Tuple>; + using TypeParameterResolutions = ImmutableDictionary>, ResolvedType>; + + /// + /// Combines a series of type parameter resolution dictionaries, IndependentResolutionDictionaries, + /// into one resolution dictionary, CombinedResolutionDictionary, containing the ultimate type + /// resolutions for all the type parameters found in the dictionaries. Validation is done on the + /// resolutions, which can be checked through the IsValid flag. + /// + public class TypeResolutionCombination + { + // Static Members + + /// + /// Checks if the given type parameter directly resolves to itself. + /// + private static bool IsSelfResolution(TypeParameterName typeParam, ResolvedType res) + { + return res.Resolution is ResolvedTypeKind.TypeParameter tp + && tp.Item.Origin.Equals(typeParam.Item1) + && tp.Item.TypeName.Equals(typeParam.Item2); + } + + /// + /// Reverses the dependencies of type parameters resolving to other type parameters in the given + /// dictionary to create a lookup whose keys are type parameters and whose values are all the type + /// parameters that can be updated by knowing the resolution of the lookup's associated key. + /// + private static ILookup GetReplaceable(TypeParameterResolutions.Builder typeParamResolutions) + { + return typeParamResolutions + .Select(kvp => (kvp.Key, GetTypeParameters.Apply(kvp.Value))) // Get any type parameters in the resolution type. + .SelectMany(tup => tup.Item2.Select(value => (tup.Key, value))) // For each type parameter found, match it to the dictionary key. + .ToLookup(// Reverse the keys and resulting type parameters to make the lookup. + kvp => kvp.value, + kvp => kvp.Key); + } + + // Fields and Properties + + /// + /// Array of all the type parameter resolution dictionaries that are combined in this combination. + /// The items are ordered such that dictionaries containing type parameters resolutions that + /// reference type parameters in other dictionaries appear before those dictionaries containing + /// the referenced type parameters. I.e., dictionary A depends on dictionary B, so A should come before B. + /// + public readonly ImmutableArray IndependentResolutionDictionaries; + + /// + /// The resulting resolution dictionary from combining all the input resolutions in + /// IndependentResolutionDictionaries. Represents a combination of all the type parameters + /// found and their ultimate type resolutions. + /// + public TypeParameterResolutions CombinedResolutionDictionary { get; private set; } = TypeParameterResolutions.Empty; + + /// + /// Flag for if there were any invalid scenarios encountered while creating the combination. + /// Invalid scenarios include a type parameter being assigned to multiple conflicting types + /// and a type parameter being assigned to a type referencing a different type parameter of + /// the same callable. Has value true if no invalid scenarios were encountered. + /// + public bool IsValid => !this.combinesOverConflictingResolution && !this.combinesOverParameterConstriction; + + /// + /// Flag for if, at any time in the creation of the combination, there was a type parameter that + /// was assigned conflicting type resolutions. Has value true if a conflict was encountered. + /// + private bool combinesOverConflictingResolution = false; + + /// + /// Flag for if, at any time in the creation of the combination, there was a type parameter that + /// was assigned a type resolution referencing a different type parameter of the same callable. + /// + private bool combinesOverParameterConstriction = false; + + // Constructors + + /// + /// Creates a type parameter resolution combination from the independent type parameter resolutions + /// found in the given typed expression and its sub expressions. Only sub-expressions whose + /// type parameter resolutions are relevant to the given expression's type parameter resolutions + /// are considered. + /// + public TypeResolutionCombination(TypedExpression expression) : this(GetTypeParameterResolutions.Apply(expression)) + { + } + + /// + /// Creates a type parameter resolution combination from independent type parameter resolution dictionaries. + /// The given resolutions are expected to be ordered such that dictionaries containing type parameters resolutions that + /// reference type parameters in other dictionaries appear before those dictionaries containing the referenced type parameters. + /// I.e., dictionary A depends on dictionary B, so A should come before B. When using this method to resolve + /// the resolutions of a nested expression, this means that the innermost resolutions should come first, followed by + /// the next innermost, and so on until the outermost expression is given last. Empty and null dictionaries are ignored. + /// + internal TypeResolutionCombination(params TypeParameterResolutions[] independentResolutionDictionaries) + { + // Filter out empty dictionaries + this.IndependentResolutionDictionaries = independentResolutionDictionaries.Where(res => !(res is null || res.IsEmpty)).ToImmutableArray(); + + if (this.IndependentResolutionDictionaries.Any()) + { + this.CombineTypeResolutions(); + } + } + + // Methods + + /// + /// Updates the combinesOverParameterConstriction flag. If the flag is already set to true, + /// nothing will be done. If not, the given type parameter will be checked against the given + /// resolution for type parameter constriction, which is when one type parameter is dependent + /// on another type parameter of the same callable. + /// + private void UpdateConstrictionFlag(TypeParameterName typeParamName, ResolvedType typeParamResolution) + { + this.combinesOverParameterConstriction = this.combinesOverParameterConstriction + || CheckForConstriction.Apply(typeParamName, typeParamResolution); + } + + /// + /// Uses the given lookup, mayBeReplaced, to determine what records in the combinedBuilder can be updated + /// from the given type parameter, typeParam, and its resolution, paramRes. Then updates the combinedBuilder + /// appropriately. + /// + private void UpdatedReplaceableResolutions( + ILookup mayBeReplaced, + TypeParameterResolutions.Builder combinedBuilder, + TypeParameterName typeParam, + ResolvedType paramRes) + { + // Create a dictionary with just the current resolution in it. + var singleResolution = new[] { 0 }.ToImmutableDictionary(_ => typeParam, _ => paramRes); + + // Get all the parameters whose value is dependent on the current resolution's type parameter, + // and update their values with this resolution's value. + foreach (var keyInCombined in mayBeReplaced[typeParam]) + { + // Check that we are not constricting a type parameter to another type parameter of the same callable. + this.UpdateConstrictionFlag(keyInCombined, paramRes); + combinedBuilder[keyInCombined] = ResolvedType.ResolveTypeParameters(singleResolution, combinedBuilder[keyInCombined]); + } + } + + /// + /// Combines independent resolutions in a disjointed dictionary, resulting in a + /// resolution dictionary that has type parameter keys that are not referenced + /// in its values. Null mappings are removed in the resulting dictionary. + /// Returns the resulting dictionary. + /// + private TypeParameterResolutions CombineTypeResolutionDictionary(TypeParameterResolutions independentResolutions) + { + var combinedBuilder = ImmutableDictionary.CreateBuilder(); + + foreach (var (typeParam, paramRes) in independentResolutions) + { + // Skip any null mappings + if (paramRes is null) + { + continue; + } + + // Contains a lookup of all the keys in the combined resolutions whose value needs to be updated + // if a certain type parameter is resolved by the currently processed dictionary. + var mayBeReplaced = GetReplaceable(combinedBuilder); + + // Check that we are not constricting a type parameter to another type parameter of the same callable + // both before and after updating the current value with the resolutions processed so far. + this.UpdateConstrictionFlag(typeParam, paramRes); + var resolvedParamRes = ResolvedType.ResolveTypeParameters(combinedBuilder.ToImmutable(), paramRes); + this.UpdateConstrictionFlag(typeParam, resolvedParamRes); + + // Do any replacements for type parameters that may be replaced with the current resolution. + this.UpdatedReplaceableResolutions(mayBeReplaced, combinedBuilder, typeParam, resolvedParamRes); + + // Add the resolution to the current dictionary. + combinedBuilder[typeParam] = resolvedParamRes; + } + + return combinedBuilder.ToImmutable(); + } + + /// + /// Combines the resolution dictionaries in the combination into one resolution dictionary containing + /// the resolutions for all the type parameters found. + /// Updates the combination with the constructed dictionary. Updates the validation flags accordingly. + /// + private void CombineTypeResolutions() + { + var combinedBuilder = ImmutableDictionary.CreateBuilder(); + + foreach (var resolutionDictionary in this.IndependentResolutionDictionaries) + { + var resolvedDictionary = this.CombineTypeResolutionDictionary(resolutionDictionary); + + // Contains a lookup of all the keys in the combined resolutions whose value needs to be updated + // if a certain type parameter is resolved by the currently processed dictionary. + var mayBeReplaced = GetReplaceable(combinedBuilder); + + // Do any replacements for type parameters that may be replaced with values in the current dictionary. + // This needs to be done first to cover an edge case. + foreach (var (typeParam, paramRes) in resolvedDictionary.Where(entry => mayBeReplaced.Contains(entry.Key))) + { + this.UpdatedReplaceableResolutions(mayBeReplaced, combinedBuilder, typeParam, paramRes); + } + + // Validate and add each resolution to the result. + foreach (var (typeParam, paramRes) in resolvedDictionary) + { + // Check that we are not constricting a type parameter to another type parameter of the same callable. + this.UpdateConstrictionFlag(typeParam, paramRes); + + // Check that there is no conflicting resolution already defined. + if (!this.combinesOverConflictingResolution) + { + this.combinesOverConflictingResolution = combinedBuilder.TryGetValue(typeParam, out var current) + && !current.Equals(paramRes) && !IsSelfResolution(typeParam, current); + } + + // Add the resolution to the current dictionary. + combinedBuilder[typeParam] = paramRes; + } + } + + this.CombinedResolutionDictionary = combinedBuilder.ToImmutable(); + } + + // Nested Classes + + /// + /// Walker that collects all of the type parameter references for a given ResolvedType + /// and returns them as a HashSet. + /// + private class GetTypeParameters : TypeTransformation + { + /// + /// Walks the given ResolvedType and returns all of the type parameters referenced. + /// + public static HashSet Apply(ResolvedType res) + { + var walker = new GetTypeParameters(); + walker.OnType(res); + return walker.SharedState.TypeParams; + } + + internal class TransformationState + { + public HashSet TypeParams = new HashSet(); + } + + private GetTypeParameters() : base(new TransformationState(), TransformationOptions.NoRebuild) + { + } + + private static TypeParameterName AsTypeResolutionKey(QsTypeParameter tp) => Tuple.Create(tp.Origin, tp.TypeName); + + public override ResolvedTypeKind OnTypeParameter(QsTypeParameter tp) + { + this.SharedState.TypeParams.Add(AsTypeResolutionKey(tp)); + return base.OnTypeParameter(tp); + } + } + + /// + /// Walker that checks a given type parameter resolution to see if it constricts + /// the type parameter to another type parameter of the same callable. + /// + private class CheckForConstriction : TypeTransformation + { + private readonly TypeParameterName typeParamName; + + /// + /// Walks the given ResolvedType, typeParamRes, and returns true if there is a reference + /// to a different type parameter of the same callable as the given type parameter, typeParam. + /// Otherwise returns false. + /// + public static bool Apply(TypeParameterName typeParam, ResolvedType typeParamRes) + { + var walker = new CheckForConstriction(typeParam); + walker.OnType(typeParamRes); + return walker.SharedState.IsConstrictive; + } + + internal class TransformationState + { + public bool IsConstrictive = false; + } + + private CheckForConstriction(TypeParameterName typeParamName) + : base(new TransformationState(), TransformationOptions.NoRebuild) + { + this.typeParamName = typeParamName; + } + + public new ResolvedType OnType(ResolvedType t) + { + // Short-circuit if we already know the type is constrictive. + if (!this.SharedState.IsConstrictive) + { + base.OnType(t); + } + + // It doesn't matter what we return because this is a walker. + return t; + } + + public override ResolvedTypeKind OnTypeParameter(QsTypeParameter tp) + { + // If the type parameter is from the same callable, but is a different parameter, + // then the type resolution is constrictive. + if (tp.Origin.Equals(this.typeParamName.Item1) && !tp.TypeName.Equals(this.typeParamName.Item2)) + { + this.SharedState.IsConstrictive = true; + } + + return base.OnTypeParameter(tp); + } + } + + /// + /// Walker that returns the relevant type parameter resolution dictionaries from a given + /// TypedExpression and its sub-expressions. + /// + private class GetTypeParameterResolutions : ExpressionTransformation + { + /// + /// Walk the given TypedExpression, collecting type parameter resolution dictionaries relevant to + /// the type parameter resolutions of the topmost expression. Returns the resolution dictionaries + /// ordered from the innermost expression's resolutions to the outermost expression's resolutions. + /// + public static TypeParameterResolutions[] Apply(TypedExpression expression) + { + var walker = new GetTypeParameterResolutions(); + walker.OnTypedExpression(expression); + return walker.SharedState.Resolutions.ToArray(); + } + + internal class TransformationState + { + public List Resolutions = new List(); + public bool InCallLike = false; + } + + private GetTypeParameterResolutions() : base(new TransformationState(), TransformationOptions.NoRebuild) + { + } + + public override TypedExpression OnTypedExpression(TypedExpression ex) + { + if (ex.Expression is ExpressionKind.CallLikeExpression call) + { + if (!this.SharedState.InCallLike || TypedExpression.IsPartialApplication(call)) + { + var contextInCallLike = this.SharedState.InCallLike; + this.SharedState.InCallLike = true; + this.OnTypedExpression(call.Item1); + this.SharedState.Resolutions.Add(ex.TypeParameterResolutions); + this.SharedState.InCallLike = contextInCallLike; + } + } + else if (ex.Expression is ExpressionKind.AdjointApplication adj) + { + this.OnTypedExpression(adj.Item); + } + else if (ex.Expression is ExpressionKind.ControlledApplication ctrl) + { + this.OnTypedExpression(ctrl.Item); + } + else + { + this.SharedState.Resolutions.Add(ex.TypeParameterResolutions); + } + + return ex; + } + } + } +}