{-# OPTIONS --cubical-compatible --safe #-}
module Tactic.RingSolver where
open import Algebra
open import Data.Fin.Base as Fin using (Fin)
open import Data.Vec.Base as Vec using (Vec; _∷_; [])
open import Data.List.Base as List using (List; _∷_; [])
open import Data.Maybe.Base as Maybe using (Maybe; just; nothing; fromMaybe)
open import Data.Nat.Base using (ℕ; suc; zero; _<ᵇ_)
open import Data.Bool.Base using (Bool; if_then_else_; true; false)
open import Data.Unit.Base using (⊤)
open import Data.String.Base as String using (String; _++_; parens)
open import Data.Product.Base using (_,_; proj₁)
open import Function.Base
open import Relation.Nullary.Decidable
open import Reflection
open import Reflection.AST.Argument
open import Reflection.AST.Term as Term
open import Reflection.AST.AlphaEquality
open import Reflection.AST.Name as Name
open import Reflection.TCM.Syntax
open import Data.Nat.Reflection
open import Data.List.Reflection
import Data.Vec.Reflection as Vec
open import Tactic.RingSolver.NonReflective renaming (solve to solver)
open import Tactic.RingSolver.Core.AlmostCommutativeRing
open import Tactic.RingSolver.Core.NatSet as NatSet
open AlmostCommutativeRing
private
-- Associates a de Bruijn index with the solver-expression term that
-- should replace that variable during conversion, or `nothing` when the
-- variable is to be kept as an opaque constant.
VarMap : Set
VarMap = (index : ℕ) → Maybe Term
-- The underlying term of a visible argument; `nothing` for an argument
-- with any other visibility.
getVisible : Arg Term → Maybe Term
getVisible = λ where
  (arg (arg-info visible _) x) → just x
  _                            → nothing
-- `lastVisibleArgs n args` discards the non-visible arguments of `args`
-- and returns the final `n` of the remaining ones, in their original
-- order, or `nothing` if fewer than `n` visible arguments remain.
lastVisibleArgs : ∀ n → Args Term → Maybe (Vec Term n)
lastVisibleArgs n args =
  Maybe.map Vec.reverse (List.foldl step base (List.mapMaybe getVisible args) n)
  where
  -- The continuation left once every argument has been folded in: it
  -- can only produce the empty vector.
  base : ∀ n → Maybe (Vec Term n)
  base zero    = just []
  base (suc _) = nothing

  -- Folding left-to-right builds a continuation that, when asked for
  -- `suc m` elements, emits the most recently folded argument first;
  -- this is why the result is reversed above.
  step : (∀ n → Maybe (Vec Term n)) → Term → ∀ n → Maybe (Vec Term n)
  step _    _ zero    = just []
  step rest t (suc m) = Maybe.map (t ∷_) (rest m)

-- The last `n` visible arguments of an application, in order.
--
-- Both definition heads (`def`) and bound-variable heads (`var`) are
-- accepted. The latter covers goals whose relation reduces to a bound
-- variable applied to `lhs rhs` (presumably e.g. an equality taken as a
-- module parameter), which would otherwise be rejected as malformed.
-- Any other head, or too few visible arguments, yields `nothing`.
getVisibleArgs : ∀ n → Term → Maybe (Vec Term n)
getVisibleArgs n (def _ xs) = lastVisibleArgs n xs
getVisibleArgs n (var _ xs) = lastVisibleArgs n xs
getVisibleArgs _ _          = nothing
-- Reflects a set of de Bruijn indices as a vector term
-- `var i₁ [] ∷ … ∷ var iₖ [] ∷ []`, in `NatSet.toList` order.
curriedTerm : NatSet → Term
curriedTerm indices = List.foldr cons Vec.`[] (NatSet.toList indices)
  where
  cons : ℕ → Term → Term
  cons i rest = var i [] Vec.`∷ rest
-- The reflected type `AlmostCommutativeRing _ _`, used to check that the
-- user-supplied ring has the expected type.
`AlmostCommutativeRing : Term
`AlmostCommutativeRing = def (quote AlmostCommutativeRing) levelArgs
  where
  -- Two explicit arguments produced by `_⋯⟨∷⟩_` (presumably the two
  -- universe levels, left unspecified for Agda to infer).
  levelArgs : Args Term
  levelArgs = 2 ⋯⟨∷⟩ []
-- The normalised terms of the ring's operators. The term converter
-- compares a normalised unknown name against these to recognise an
-- operator that has been unfolded to its underlying definition.
record RingOperatorTerms : Set where
  constructor add⇒_mul⇒_pow⇒_neg⇒_sub⇒_
  field
    add : Term
    mul : Term
    pow : Term
    neg : Term
    sub : Term
-- Type-checks the given term against `AlmostCommutativeRing _ _` and
-- returns the elaborated term.
checkIsRing : Term → TC Term
checkIsRing = λ ring → checkType ring `AlmostCommutativeRing
-- Helpers for building reflected terms that mention the ring `` `ring ``.
module RingReflection (`ring : Term) where

  -- `nm $ʳ rest` builds `def nm` applied to two hidden arguments
  -- (produced by `_⋯⟅∷⟆_`), then the ring as an explicit argument, then
  -- `rest`.
  infixr 6 _$ʳ_
  _$ʳ_ : Name → Args Term → Term
  nm $ʳ rest = def nm (2 ⋯⟅∷⟆ `ring ⟨∷⟩ rest)

  -- The ring's carrier type.
  `Carrier : Term
  `Carrier = quote Carrier $ʳ []

  -- Reflexivity in the ring, with one further hidden argument.
  `refl : Term
  `refl = quote refl $ʳ (1 ⋯⟅∷⟆ [])

  -- Symmetry applied to the proof `p`.
  `sym : Term → Term
  `sym p = quote sym $ʳ (2 ⋯⟅∷⟆ p ⟨∷⟩ [])

  -- Transitivity applied to the proofs `p` and `q`.
  `trans : Term → Term → Term
  `trans p q = quote trans $ʳ (3 ⋯⟅∷⟆ p ⟨∷⟩ q ⟨∷⟩ [])

  -- Normalises each ring operator, in the order +, *, ^, -, and binary
  -- minus. The resulting terms can be compared against normalised names
  -- met while converting the goal.
  getRingOperatorTerms : TC RingOperatorTerms
  getRingOperatorTerms = do
    addTerm ← normalise (quote _+_ $ʳ [])
    mulTerm ← normalise (quote _*_ $ʳ [])
    powTerm ← normalise (quote _^_ $ʳ [])
    negTerm ← normalise (quote (-_) $ʳ [])
    subTerm ← normalise (quote _-_ $ʳ [])
    pure (add⇒ addTerm mul⇒ mulTerm pow⇒ powTerm neg⇒ negTerm sub⇒ subTerm)
-- Builders for reflected solver terms, specialised to one ring term
-- (`ring`) and a fixed number of solver variables.
module RingSolverReflection (ring : Term) (numberOfVariables : ℕ) where
  open RingReflection ring

  -- `numberOfVariables` reflected as a term (via `toTerm`).
  `numberOfVariables : Term
  `numberOfVariables = toTerm numberOfVariables

  -- `e $ᵉ xs` applies the constructor `e` to `xs`. It first supplies three
  -- hidden arguments: one produced by `1 ⋯⟅∷⟆`, then the ring's carrier,
  -- then the number of variables.
  -- NOTE(review): presumably these are the level, carrier and variable
  -- count indexing the solver's `Expr` type; confirm against
  -- Tactic.RingSolver.NonReflective.
  infix -1 _$ᵉ_
  _$ᵉ_ : Name → List (Arg Term) → Term
  e $ᵉ xs = con e (1 ⋯⟅∷⟆ `Carrier ⟅∷⟆ `numberOfVariables ⟅∷⟆ xs)

  -- Wraps a term in the `Κ` constructor. `convertTerm` below uses this
  -- for every term it treats as an opaque constant.
  `Κ : Term → Term
  `Κ x = quote Κ $ᵉ (x ⟨∷⟩ [])

  -- Wraps a term in the `Ι` constructor (presumably the variable case of
  -- the solver's expression type).
  `I : Term → Term
  `I x = quote Ι $ᵉ (x ⟨∷⟩ [])

  -- Builds `x ⊜ y` for the ring, passing the variable count as a hidden
  -- argument.
  infixl 6 _`⊜_
  _`⊜_ : Term → Term → Term
  x `⊜ y = quote _⊜_ $ʳ (`numberOfVariables ⟅∷⟆ x ⟨∷⟩ y ⟨∷⟩ [])

  -- Builds a call to `Ops.correct` on the expression `x` and the
  -- environment `ρ`, with one extra hidden argument.
  `correct : Term → Term → Term
  `correct x ρ = quote Ops.correct $ʳ (1 ⋯⟅∷⟆ x ⟨∷⟩ ρ ⟨∷⟩ [])

  -- Builds a call to the non-reflective solver (imported here as
  -- `solver`) on the variable count and the terms `f` and `eq`.
  `solver : Term → Term → Term
  `solver `f `eq = quote solver $ʳ (`numberOfVariables ⟨∷⟩ `f ⟨∷⟩ `eq ⟨∷⟩ [])

  -- Converts a reflected ring term into a reflected solver expression.
  --
  -- Ring operators are recognised in one of two ways:
  --   * directly by the quoted operator name; or
  --   * for any other name, by normalising the name and comparing it
  --     (up to α-equality) with the normalised operators in
  --     `operatorTerms`.
  -- A bound variable is looked up in `varMap`. `suc x` is read as
  -- `1 + x`. Anything else becomes a constant via `Κ`.
  convertTerm : RingOperatorTerms → VarMap → Term → TC Term
  convertTerm operatorTerms varMap = convert
    where
    open RingOperatorTerms operatorTerms

    mutual
      -- Clause order matters: the specific `def (quote …)` clauses must
      -- precede the general `def nm xs` fallback.
      convert : Term → TC Term
      convert (def (quote _+_) xs) = convertOp₂ (quote _⊕_) xs
      convert (def (quote _*_) xs) = convertOp₂ (quote _⊗_) xs
      convert (def (quote -_) xs) = convertOp₁ (quote ⊝_) xs
      convert (def (quote _^_) xs) = convertExp xs
      convert (def (quote _-_) xs) = convertSub xs
      convert (def nm xs) = convertUnknownName nm xs
      -- A variable absent from `varMap` is kept as a constant, together
      -- with its arguments.
      -- NOTE(review): when `varMap x` is `just`, the variable's own
      -- arguments (the `_`) are discarded.
      convert v@(var x _) = pure $ fromMaybe (`Κ v) (varMap x)
      -- `suc x` (presumably only arising when the carrier is ℕ).
      convert (`suc x) = convertSuc x
      convert t = pure $ `Κ t

      -- Binary operator: skips leading arguments (e.g. the hidden ones
      -- and the ring added by `_$ʳ_`) until exactly two remain. Both must
      -- match `⟨∷⟩`, i.e. be visible; they are the operands.
      -- NOTE(review): when no such suffix exists the result is
      -- `unknown`. No error is reported here; it presumably surfaces
      -- later when the generated term is checked.
      convertOp₂ : Name → Args Term → TC Term
      convertOp₂ nm (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        y' ← convert y
        pure (nm $ᵉ (x' ⟨∷⟩ y' ⟨∷⟩ []))
      convertOp₂ nm (x ∷ xs) = convertOp₂ nm xs
      convertOp₂ _ _ = pure unknown

      -- Unary operator: as `convertOp₂`, but with a single trailing
      -- visible operand.
      convertOp₁ : Name → Args Term → TC Term
      convertOp₁ nm (x ⟨∷⟩ []) = do
        x' ← convert x
        pure (nm $ᵉ (x' ⟨∷⟩ []))
      convertOp₁ nm (x ∷ xs) = convertOp₁ nm xs
      convertOp₁ _ _ = pure unknown

      -- Exponentiation `x ^ y` becomes `x ⊛ y`. Only the base is
      -- converted; the exponent `y` is passed through unchanged.
      convertExp : Args Term → TC Term
      convertExp (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        pure (quote _⊛_ $ᵉ (x' ⟨∷⟩ y ⟨∷⟩ []))
      convertExp (x ∷ xs) = convertExp xs
      convertExp _ = pure unknown

      -- Subtraction `x - y` becomes `x ⊕ (⊝ y)`.
      convertSub : Args Term → TC Term
      convertSub (x ⟨∷⟩ y ⟨∷⟩ []) = do
        x' ← convert x
        -y' ← convertOp₁ (quote (⊝_)) (y ⟨∷⟩ [])
        pure (quote _⊕_ $ᵉ x' ⟨∷⟩ -y' ⟨∷⟩ [])
      convertSub (x ∷ xs) = convertSub xs
      convertSub _ = pure unknown

      -- A name not matched syntactically above is normalised and then
      -- compared, in order, with the normalised +, *, -, ^ and binary
      -- minus. If none matches, the whole application is a constant.
      convertUnknownName : Name → Args Term → TC Term
      convertUnknownName nm xs = do
        nameTerm ← normalise (def nm [])
        if (nameTerm =α= add) then convertOp₂ (quote _⊕_) xs else
          if (nameTerm =α= mul) then convertOp₂ (quote _⊗_) xs else
          if (nameTerm =α= neg) then convertOp₁ (quote ⊝_) xs else
          if (nameTerm =α= pow) then convertExp xs else
          if (nameTerm =α= sub) then convertSub xs else
          pure (`Κ (def nm xs))

      -- `suc x` becomes `Κ 1 ⊕ x`, where `1` is built with `toTerm`.
      convertSuc : Term → TC Term
      convertSuc x = do x' ← convert x; pure (quote _⊕_ $ᵉ (`Κ (toTerm 1) ⟨∷⟩ x' ⟨∷⟩ []))
open RingReflection
open RingSolverReflection
-- Error reported by `solve-∀` when the goal's body does not expose two
-- visible arguments after its ∀-binders are stripped.
malformedForallTypeError : ∀ {a} {A : Set a} → Term → TC A
malformedForallTypeError found = typeError $
  strErr "Malformed call to solve." ∷
  strErr "Expected target type to be like: ∀ x y → x + y ≈ y + x." ∷
  strErr "Instead: " ∷
  termErr found ∷
  []
-- Variable map for `solve-∀`. Indices below `numVars` (the stripped
-- ∀-binders) are mapped to the same de Bruijn index, because the solver
-- call re-binds the same number of variables. Every other index is left
-- as a constant.
quantifiedVarMap : ℕ → VarMap
quantifiedVarMap numVars i with i <ᵇ numVars
... | true  = just (var i [])
... | false = nothing
-- Builds the call `solver n (λ x₁ … xₙ → lhs ⊜ rhs) (λ {x₁ … xₙ} → refl)`.
-- The lambdas are added with `prependVLams`/`prependHLams`, one per name
-- in `variables`, and `lhs`/`rhs` are first converted into solver
-- expressions.
constructCallToSolver : Term → RingOperatorTerms → List String → Term → Term → TC Term
constructCallToSolver `ring opNames variables `lhs `rhs = do
  lhsExpr ← toExpr `lhs
  rhsExpr ← toExpr `rhs
  let equation = _`⊜_ `ring n lhsExpr rhsExpr
  let proof = prependHLams variables (`refl `ring)
  pure (`solver `ring n (prependVLams variables equation) proof)
  where
  n : ℕ
  n = List.length variables

  toExpr : Term → TC Term
  toExpr = convertTerm `ring n opNames (quantifiedVarMap n)
-- Implementation of the `solve-∀` macro. The hole's type must reduce to
-- a ∀-quantified equation whose body has two trailing visible arguments
-- (the two sides). On success the hole is unified with a call to the
-- non-reflective solver.
solve-∀-macro : Name → Term → TC ⊤
solve-∀-macro ring hole = do
  `ring ← checkIsRing (def ring [])
  commitTC
  opTerms ← getRingOperatorTerms `ring
  goal ← inferType hole >>= reduce
  let telescope , body = stripPis goal
  let names = List.map proj₁ telescope
  just (lhs ∷ rhs ∷ []) ← pure (getVisibleArgs 2 body)
    where nothing → malformedForallTypeError goal
  call ← constructCallToSolver `ring opTerms names lhs rhs
  unify hole call
macro
  -- `solve-∀ R` proves a goal of the form `∀ x y … → lhs ≈ rhs` in the
  -- ring named `R`.
  solve-∀ : Name → Term → TC ⊤
  solve-∀ ring hole = solve-∀-macro ring hole
-- Error reported by `solve` when its first argument does not normalise
-- to a literal list of bound variables.
malformedArgumentListError : ∀ {a} {A : Set a} → Term → TC A
malformedArgumentListError found = typeError $
  strErr "Malformed call to solve." ∷
  strErr "First argument should be a list of free variables." ∷
  strErr "Instead: " ∷
  termErr found ∷
  []
-- Error reported by `solve` when the goal type does not expose two
-- visible arguments (the two sides of the equation).
malformedGoalError : ∀ {a} {A : Set a} → Term → TC A
malformedGoalError found = typeError $
  strErr "Malformed call to solve." ∷
  strErr "Goal type should be of the form: LHS ≈ RHS" ∷
  strErr "Instead: " ∷
  termErr found ∷
  []
-- Type-checks `xs` as a `List` of the ring's carrier and returns its
-- normal form.
checkIsListOfVariables : Term → Term → TC Term
checkIsListOfVariables `ring `xs = do
  checked ← checkType `xs (`List (`Carrier `ring))
  normalise checked
-- Collects the de Bruijn indices of a reflected list literal whose
-- elements are all unapplied variables. Returns `nothing` for any other
-- shape.
getVariableIndices : Term → Maybe NatSet
getVariableIndices = collect []
  where
  collect : NatSet → Term → Maybe NatSet
  collect acc `[]`              = just acc
  collect acc (var i [] `∷` is) = collect (insert i acc) is
  collect _   _                 = nothing
-- Builds the proof `trans (sym (correct lhs ρ)) (correct rhs ρ)`.
-- Here `lhs`/`rhs` are the converted solver expressions and `ρ` is the
-- vector of the chosen variables. A variable whose index is found by
-- `lookup` in `variables` becomes `Ι` of the corresponding `Fin` term.
constructSolution : Term → RingOperatorTerms → NatSet → Term → Term → TC Term
constructSolution `ring opTerms variables `lhs `rhs = do
  lhsProof ← prove `lhs
  rhsProof ← prove `rhs
  pure (`trans `ring (`sym `ring lhsProof) rhsProof)
  where
  numVars : ℕ
  numVars = List.length variables

  varMap : VarMap
  varMap i = Maybe.map (`I `ring numVars ∘ toFinTerm) (lookup variables i)

  env : Term
  env = curriedTerm variables

  prove : Term → TC Term
  prove t = do
    expr ← convertTerm `ring numVars opTerms varMap t
    pure (`correct `ring numVars expr env)
-- Implementation of the `solve` macro. It validates the ring, then the
-- variable list, then the goal shape, and unifies the hole with the
-- constructed proof.
solve-macro : Term → Name → Term → TC ⊤
solve-macro variables ring hole = do
  `ring ← checkIsRing (def ring [])
  commitTC
  opTerms ← getRingOperatorTerms `ring
  vars ← checkIsListOfVariables `ring variables
  commitTC
  just indices ← pure (getVariableIndices vars)
    where nothing → malformedArgumentListError vars
  goal ← inferType hole >>= reduce
  just (lhs ∷ rhs ∷ []) ← pure (getVisibleArgs 2 goal)
    where nothing → malformedGoalError goal
  proof ← constructSolution `ring opTerms indices lhs rhs
  unify hole proof
macro
  -- `solve (x ∷ y ∷ …) R` proves a goal `lhs ≈ rhs` in the ring named
  -- `R`, treating the listed variables as the solver's variables.
  solve : Term → Name → Term → TC ⊤
  solve vars ring hole = solve-macro vars ring hole